package training.feature_discovery;

import com.itextpdf.text.pdf.ColumnText;
import features.FeatureVector;
import features.feature_sets.BaseFeatureSet;
import features.spatial.FeatureUtils;
import features.spatial.SpatialFeature;
import features.spatial.instances.FeatureInstance;
import game.Game;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import main.collections.FVector;
import main.collections.FastArrayList;
import main.collections.ListUtils;
import other.move.Move;
import policies.softmax.SoftmaxPolicyLinear;
import training.ExperienceSample;
import training.expert_iteration.gradients.Gradients;
import training.expert_iteration.params.FeatureDiscoveryParams;
import training.expert_iteration.params.ObjectiveParams;
import training.feature_discovery.FeatureSetExpander;
import utils.experiments.InterruptableExperiment;

/* loaded from: input_file:training/feature_discovery/ErrorReductionExpander.class */
/**
 * Feature set expander that proposes new spatial features by combining two feature
 * instances that are active for the same action.
 * <p>
 * For every experience sample, the errors of the current policy's distribution with
 * respect to the expert distribution are computed. For each action, active feature
 * instances are kept (or sampled) as combination candidates, and the action's error is
 * recorded for every kept instance and for every pair of kept instances. A pair is scored
 * by the maximum error reduction over the pair's recorded errors minus the maximum error
 * reductions of its two constituent instances (see {@link #computeMaxErrorReduction}).
 * The best-scoring non-reactive and the best-scoring reactive combined feature for which
 * {@code createExpandedFeatureSet} returns a non-null set are added.
 */
public class ErrorReductionExpander implements FeatureSetExpander {

    /** Maximum number of not-yet-seen feature instances sampled as candidates per action. */
    private static final int MAX_NEW_INSTANCES_PER_ACTION = 15;

    /**
     * Expands the given feature set with at most one new non-reactive and at most one new
     * reactive spatial feature.
     *
     * @param list experience samples to learn from
     * @param baseFeatureSet feature set to expand
     * @param softmaxPolicyLinear current policy; its distributions are compared to the experts'
     * @param game2 game for which features are constructed
     * @param i maximum total number of feature instances that may be preserved as
     *          combination candidates over all samples
     * @param objectiveParams unused by this implementation
     * @param featureDiscoveryParams unused by this implementation
     * @param tDoubleArrayList unused by this implementation
     * @param printWriter unused by this implementation
     * @param interruptableExperiment unused by this implementation
     * @return the expanded feature set, or {@code baseFeatureSet} itself if nothing was added
     */
    @Override // training.feature_discovery.FeatureSetExpander
    public BaseFeatureSet expandFeatureSet(List<? extends ExperienceSample> list, BaseFeatureSet baseFeatureSet, SoftmaxPolicyLinear softmaxPolicyLinear, Game game2, int i, ObjectiveParams objectiveParams, FeatureDiscoveryParams featureDiscoveryParams, TDoubleArrayList tDoubleArrayList, PrintWriter printWriter, InterruptableExperiment interruptableExperiment) {
        final int numFeatures = baseFeatureSet.getNumSpatialFeatures();
        final int numSamples = list.size();

        // Features already in the set; combinations that reproduce one of them are skipped
        final Set<SpatialFeature> existingFeatures = new HashSet<>((int) Math.ceil(numFeatures / 0.75f), 0.75f);
        for (final SpatialFeature feature : baseFeatureSet.spatialFeatures()) {
            existingFeatures.add(feature);
        }

        // --- 1. Policy errors, split per feature by whether the feature was active ---
        final FVector[] errorVectors = new FVector[numSamples];
        final float[] totalAbsErrors = new float[numSamples];
        final TDoubleArrayList[] errorsWhenActive = new TDoubleArrayList[numFeatures];
        final TDoubleArrayList[] errorsWhenInactive = new TDoubleArrayList[numFeatures];
        for (int f = 0; f < numFeatures; ++f) {
            errorsWhenActive[f] = new TDoubleArrayList();
            errorsWhenInactive[f] = new TDoubleArrayList();
        }

        double avgActionError = 0.0;    // running mean over all actions of all samples
        int numActionsSeen = 0;
        for (int s = 0; s < numSamples; ++s) {
            final ExperienceSample sample = list.get(s);
            final FeatureVector[] featureVectors = sample.generateFeatureVectors(baseFeatureSet);
            final FVector policy = softmaxPolicyLinear.computeDistribution(featureVectors, sample.gameState().mover());
            final FVector errors = Gradients.computeDistributionErrors(policy, sample.expertDistribution());

            for (int a = 0; a < featureVectors.length; ++a) {
                final float actionError = errors.get(a);
                final TIntArrayList activeIndices = featureVectors[a].activeSpatialFeatureIndices();
                activeIndices.sort();   // required by the merge-style walk below
                int nextActive = 0;
                for (int f = 0; f < numFeatures; ++f) {
                    if (nextActive < activeIndices.size() && activeIndices.getQuick(nextActive) == f) {
                        errorsWhenActive[f].add(actionError);
                        ++nextActive;
                    } else {
                        errorsWhenInactive[f].add(actionError);
                    }
                }
                avgActionError += (actionError - avgActionError) / (numActionsSeen + 1);
                ++numActionsSeen;
            }

            final FVector absErrors = errors.copy();
            absErrors.abs();
            errorVectors[s] = errors;
            totalAbsErrors[s] = absErrors.sum();
        }

        // --- 2. Per-feature scores guiding which candidate instances get sampled ---
        final double[] featureScores = computeFeatureScores(errorsWhenActive, errorsWhenInactive, avgActionError);

        // --- 3. Visit samples in decreasing order of total absolute error ---
        final List<Integer> sampleOrder = new ArrayList<>(numSamples);
        for (int s = 0; s < numSamples; ++s) {
            sampleOrder.add(s);
        }
        // Float.compare is a consistent total order, also in the presence of NaN
        Collections.sort(sampleOrder, (lhs, rhs) -> Float.compare(totalAbsErrors[rhs.intValue()], totalAbsErrors[lhs.intValue()]));

        // --- 4. Record errors observed for kept instances and pairs of kept instances ---
        final Set<CombinableFeatureInstancePair> preservedInstances = new HashSet<>();
        final Set<CombinableFeatureInstancePair> discardedInstances = new HashSet<>();
        final Map<CombinableFeatureInstancePair, TDoubleArrayList> errorLists = new HashMap<>();

        for (final int sampleIdx : sampleOrder) {
            final ExperienceSample sample = list.get(sampleIdx);
            final FVector errors = errorVectors[sampleIdx];
            final float minError = errors.min();
            final float maxError = errors.max();
            final FastArrayList<Move> moves = sample.moves();
            final BitSet winningMoves = sample.winningMoves();
            final BitSet losingMoves = sample.losingMoves();
            final BitSet antiDefeatingMoves = sample.antiDefeatingMoves();
            final TIntArrayList actionOrder = prioritizedActionOrder(moves.size(), winningMoves, losingMoves, antiDefeatingMoves);

            for (int k = 0; k < actionOrder.size(); ++k) {
                final int actionIdx = actionOrder.getQuick(k);
                final Move move = moves.get(actionIdx);
                final Set<CombinableFeatureInstancePair> observedSingletons = new HashSet<>(256, 0.75f);

                // Deduplicated active instances: 'candidates' is filtered below, 'allActive' is not
                final List<FeatureInstance> candidates = new ArrayList<>(new HashSet<FeatureInstance>(baseFeatureSet.getActiveSpatialFeatureInstances(sample.gameState(), sample.lastFromPos(), sample.lastToPos(), FeatureUtils.fromPos(move), FeatureUtils.toPos(move), move.mover())));
                final List<FeatureInstance> allActive = new ArrayList<>(candidates);
                final List<FeatureInstance> instancesToCombine = new ArrayList<>();
                final List<CombinableFeatureInstancePair> singletonsToCombine = new ArrayList<>();
                final List<CombinableFeatureInstancePair> candidateSingletons = new ArrayList<>();

                // Previously preserved instances are always combined, previously discarded ones
                // are dropped, and the rest remain candidates for sampling
                int c = 0;
                while (c < candidates.size()) {
                    final FeatureInstance instance = candidates.get(c);
                    final CombinableFeatureInstancePair singleton = new CombinableFeatureInstancePair(game2, instance, instance);
                    if (preservedInstances.contains(singleton)) {
                        singletonsToCombine.add(singleton);
                        instancesToCombine.add(instance);
                        candidates.remove(c);
                    } else if (discardedInstances.contains(singleton)) {
                        candidates.remove(c);
                    } else {
                        candidateSingletons.add(singleton);
                        ++c;
                    }
                }

                final int numToSample = Math.min(Math.min(MAX_NEW_INSTANCES_PER_ACTION, i - preservedInstances.size()), candidates.size());
                if (numToSample > 0) {
                    final FVector distr = samplingDistribution(candidates, allActive, featureScores);
                    for (int n = 0; n < numToSample; ++n) {
                        final int sampled = distr.sampleFromDistribution();
                        final CombinableFeatureInstancePair singleton = candidateSingletons.get(sampled);
                        instancesToCombine.add(candidates.get(sampled));
                        singletonsToCombine.add(singleton);
                        preservedInstances.add(singleton);
                        // presumably removes 'sampled' from subsequent draws (sampling without replacement)
                        distr.updateSoftmaxInvalidate(sampled);
                    }
                }

                // Candidates that were not preserved are never considered again
                for (final FeatureInstance instance : candidates) {
                    final CombinableFeatureInstancePair singleton = new CombinableFeatureInstancePair(game2, instance, instance);
                    if (!preservedInstances.contains(singleton)) {
                        discardedInstances.add(singleton);
                    }
                }

                // Error recorded for this action; winning / losing / anti-defeating moves use
                // targets derived from the sample's extreme errors
                float error = errors.get(actionIdx);
                if (winningMoves.get(actionIdx)) {
                    error = minError;
                } else if (losingMoves.get(actionIdx)) {
                    error = maxError;
                } else if (antiDefeatingMoves.get(actionIdx)) {
                    error = Math.min(error, minError + 0.1f);
                }

                final int numToCombine = instancesToCombine.size();
                for (int j = 0; j < numToCombine; ++j) {
                    final FeatureInstance first = instancesToCombine.get(j);
                    final CombinableFeatureInstancePair singleton = singletonsToCombine.get(j);
                    if (observedSingletons.add(singleton)) {
                        errorListFor(errorLists, singleton).add(error);
                    }
                    for (int m = j + 1; m < numToCombine; ++m) {
                        final CombinableFeatureInstancePair combined = new CombinableFeatureInstancePair(game2, first, instancesToCombine.get(m));
                        if (!existingFeatures.contains(combined.combinedFeature)) {
                            errorListFor(errorLists, combined).add(error);
                        }
                    }
                }
            }
        }

        // --- 5. Score every pair of distinct instances, split by reactivity ---
        final Comparator<ScoredFeatureInstancePair> byDescendingScore = (lhs, rhs) -> {
            if (lhs.score < rhs.score) {
                return 1;
            }
            return lhs.score > rhs.score ? -1 : 0;
        };
        final PriorityQueue<ScoredFeatureInstancePair> proactiveCandidates = new PriorityQueue<>(byDescendingScore);
        final PriorityQueue<ScoredFeatureInstancePair> reactiveCandidates = new PriorityQueue<>(byDescendingScore);
        for (final Map.Entry<CombinableFeatureInstancePair, TDoubleArrayList> entry : errorLists.entrySet()) {
            final CombinableFeatureInstancePair pair = entry.getKey();
            if (pair.a.equals(pair.b)) {
                continue;   // singleton entries only serve as baselines
            }
            final double score = computeMaxErrorReduction(entry.getValue())
                    - maxErrorReductionOf(errorLists, new CombinableFeatureInstancePair(game2, pair.a, pair.a))
                    - maxErrorReductionOf(errorLists, new CombinableFeatureInstancePair(game2, pair.b, pair.b));
            if (pair.combinedFeature.isReactive()) {
                reactiveCandidates.add(new ScoredFeatureInstancePair(pair, score));
            } else {
                proactiveCandidates.add(new ScoredFeatureInstancePair(pair, score));
            }
        }

        final BaseFeatureSet withProactive = expandWithBestCandidate(baseFeatureSet, game2, proactiveCandidates);
        return expandWithBestCandidate(withProactive, game2, reactiveCandidates);
    }

    /**
     * Computes one sampling score per spatial feature: the sum of
     * <ol>
     * <li>the Pearson correlation between the feature's 0/1 activity and the action
     *     errors (0 when undefined),</li>
     * <li>the mean absolute error over actions in which the feature is active, and</li>
     * <li>the mean of (feature value * absolute error) over all actions.</li>
     * </ol>
     */
    private static double[] computeFeatureScores(final TDoubleArrayList[] errorsWhenActive, final TDoubleArrayList[] errorsWhenInactive, final double avgActionError) {
        final int numFeatures = errorsWhenActive.length;
        final double[] scores = new double[numFeatures];
        for (int f = 0; f < numFeatures; ++f) {
            final TDoubleArrayList active = errorsWhenActive[f];
            final TDoubleArrayList inactive = errorsWhenInactive[f];
            final int numActive = active.size();
            final int numInactive = inactive.size();
            // Floating-point division: integer division would truncate this to 0 or 1
            // (and throw when there are no observations at all)
            final double avgFeatureVal = ((double) numActive) / (numActive + numInactive);

            double covariance = 0.0;
            double sumSqErrorDevs = 0.0;
            double meanAbsErrorGivenActive = 0.0;
            double meanFeatureTimesAbsError = 0.0;

            for (int j = 0; j < numActive; ++j) {
                final double error = active.getQuick(j);
                final double errorDev = error - avgActionError;
                covariance += (1.0 - avgFeatureVal) * errorDev;
                sumSqErrorDevs += errorDev * errorDev;
                meanAbsErrorGivenActive += (Math.abs(error) - meanAbsErrorGivenActive) / (j + 1);
                meanFeatureTimesAbsError += (Math.abs(error) - meanFeatureTimesAbsError) / (j + 1);
            }
            for (int j = 0; j < numInactive; ++j) {
                final double errorDev = inactive.getQuick(j) - avgActionError;
                covariance += (0.0 - avgFeatureVal) * errorDev;
                sumSqErrorDevs += errorDev * errorDev;
                // Continue the same running mean (count numActive + j + 1): inactive actions
                // contribute a feature value of 0 and must dilute, not reset, the mean
                meanFeatureTimesAbsError += (0.0 - meanFeatureTimesAbsError) / (numActive + j + 1);
            }

            final double sumSqFeatureDevs = (numActive * ((1.0 - avgFeatureVal) * (1.0 - avgFeatureVal))) + (numInactive * ((0.0 - avgFeatureVal) * (0.0 - avgFeatureVal)));
            double correlation = covariance / Math.sqrt(sumSqFeatureDevs * sumSqErrorDevs);
            if (Double.isNaN(correlation)) {
                correlation = 0.0;
            }
            scores[f] = correlation + meanAbsErrorGivenActive + meanFeatureTimesAbsError;
        }
        return scores;
    }

    /**
     * Returns action indices in visiting order: set bits of {@code winningMoves}, then of
     * {@code losingMoves}, then of {@code antiDefeatingMoves} (each ascending; an index set
     * in several of them appears once per set), followed by all remaining indices in
     * {@code [0, numMoves)} in random order.
     */
    private static TIntArrayList prioritizedActionOrder(final int numMoves, final BitSet winningMoves, final BitSet losingMoves, final BitSet antiDefeatingMoves) {
        final TIntArrayList order = new TIntArrayList();
        appendSetBits(order, winningMoves);
        appendSetBits(order, losingMoves);
        appendSetBits(order, antiDefeatingMoves);

        final TIntArrayList remaining = new TIntArrayList();
        for (int a = 0; a < numMoves; ++a) {
            if (!winningMoves.get(a) && !losingMoves.get(a) && !antiDefeatingMoves.get(a)) {
                remaining.add(a);
            }
        }
        while (!remaining.isEmpty()) {
            final int r = ThreadLocalRandom.current().nextInt(remaining.size());
            final int action = remaining.getQuick(r);
            ListUtils.removeSwap(remaining, r);
            order.add(action);
        }
        return order;
    }

    /** Appends the indices of all set bits of {@code bits}, in ascending order, to {@code dest}. */
    private static void appendSetBits(final TIntArrayList dest, final BitSet bits) {
        for (int b = bits.nextSetBit(0); b >= 0; b = bits.nextSetBit(b + 1)) {
            dest.add(b);
        }
    }

    /**
     * Builds the distribution over {@code candidates} to sample from. It is a softmax over
     * the candidates' feature scores in which each entry is divided by the number of
     * instances in {@code allActive} that share its feature, and it is then normalised.
     */
    private static FVector samplingDistribution(final List<FeatureInstance> candidates, final List<FeatureInstance> allActive, final double[] featureScores) {
        final FVector distr = new FVector(candidates.size());
        for (int j = 0; j < candidates.size(); ++j) {
            distr.set(j, (float) featureScores[candidates.get(j).feature().spatialFeatureSetIndex()]);
        }
        distr.softmax(1.0d);
        for (int j = 0; j < candidates.size(); ++j) {
            final int featureIdx = candidates.get(j).feature().spatialFeatureSetIndex();
            int numSameFeature = 0;
            for (final FeatureInstance instance : allActive) {
                if (instance.feature().spatialFeatureSetIndex() == featureIdx) {
                    ++numSameFeature;
                }
            }
            // Spread a feature's probability mass over its active instances
            distr.set(j, distr.get(j) / numSameFeature);
        }
        distr.normalise();
        return distr;
    }

    /** Returns the error list stored for {@code pair}, creating and storing an empty one if absent. */
    private static TDoubleArrayList errorListFor(final Map<CombinableFeatureInstancePair, TDoubleArrayList> errorLists, final CombinableFeatureInstancePair pair) {
        TDoubleArrayList errors = errorLists.get(pair);
        if (errors == null) {
            errors = new TDoubleArrayList();
            errorLists.put(pair, errors);
        }
        return errors;
    }

    /** Max error reduction of the errors stored for {@code key}, or 0 if none were recorded. */
    private static double maxErrorReductionOf(final Map<CombinableFeatureInstancePair, TDoubleArrayList> errorLists, final CombinableFeatureInstancePair key) {
        final TDoubleArrayList errors = errorLists.get(key);
        return errors == null ? 0.0 : computeMaxErrorReduction(errors);
    }

    /**
     * Polls candidates in priority order and returns the first non-null result of
     * {@code featureSet.createExpandedFeatureSet}. Returns {@code featureSet} unchanged
     * if the queue runs empty.
     */
    private static BaseFeatureSet expandWithBestCandidate(final BaseFeatureSet featureSet, final Game game, final PriorityQueue<ScoredFeatureInstancePair> candidates) {
        while (!candidates.isEmpty()) {
            final BaseFeatureSet expanded = featureSet.createExpandedFeatureSet(game, candidates.poll().pair.combinedFeature);
            if (expanded != null) {
                return expanded;
            }
        }
        return featureSet;
    }

    /**
     * Computes {@code sum|e| - sum|e - median(e)|}, i.e. how much the total absolute error
     * shrinks when the median is predicted instead of zero. Prints a message to stderr if
     * the result is negative.
     * <p>
     * Side effect: sorts {@code tDoubleArrayList} in place. Returns 0 for an empty list.
     */
    private static final double computeMaxErrorReduction(TDoubleArrayList tDoubleArrayList) {
        final int n = tDoubleArrayList.size();
        if (n == 0) {
            return 0.0;
        }
        tDoubleArrayList.sort();
        final int mid = (n - 1) / 2;
        final double median = (n % 2 == 0)
                ? (tDoubleArrayList.getQuick(mid) + tDoubleArrayList.getQuick(mid + 1)) / 2.0d
                : tDoubleArrayList.getQuick(mid);

        double sumAbsErrors = 0.0;
        double sumAbsDevsFromMedian = 0.0;
        for (int j = 0; j < n; ++j) {
            final double error = tDoubleArrayList.getQuick(j);
            sumAbsErrors += Math.abs(error);
            sumAbsDevsFromMedian += Math.abs(error - median);
        }
        final double reduction = sumAbsErrors - sumAbsDevsFromMedian;
        if (reduction < 0.0d) {
            System.err.println("ERROR: NEGATIVE ERROR REDUCTION!");
        }
        return reduction;
    }
}
