package training.feature_discovery;

import com.itextpdf.text.pdf.ColumnText;
import features.FeatureVector;
import features.feature_sets.BaseFeatureSet;
import features.spatial.FeatureUtils;
import features.spatial.SpatialFeature;
import features.spatial.instances.FeatureInstance;
import game.Game;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import main.collections.FVector;
import main.collections.FastArrayList;
import other.move.Move;
import policies.softmax.SoftmaxPolicyLinear;
import training.ExperienceSample;
import training.expert_iteration.gradients.Gradients;
import training.expert_iteration.params.FeatureDiscoveryParams;
import training.expert_iteration.params.ObjectiveParams;
import training.feature_discovery.FeatureSetExpander;
import utils.experiments.InterruptableExperiment;

/* loaded from: input_file:training/feature_discovery/VarianceReductionExpander.class */
public class VarianceReductionExpander implements FeatureSetExpander {
    @Override // training.feature_discovery.FeatureSetExpander
    public BaseFeatureSet expandFeatureSet(List<? extends ExperienceSample> list, BaseFeatureSet baseFeatureSet, SoftmaxPolicyLinear softmaxPolicyLinear, Game game2, int i, ObjectiveParams objectiveParams, FeatureDiscoveryParams featureDiscoveryParams, TDoubleArrayList tDoubleArrayList, PrintWriter printWriter, InterruptableExperiment interruptableExperiment) {
        int i2;
        TObjectIntHashMap tObjectIntHashMap = new TObjectIntHashMap(10, 0.5f, 0);
        TObjectDoubleHashMap tObjectDoubleHashMap = new TObjectDoubleHashMap(10, 0.5f, 0.0d);
        TObjectDoubleHashMap tObjectDoubleHashMap2 = new TObjectDoubleHashMap(10, 0.5f, 0.0d);
        TObjectDoubleHashMap tObjectDoubleHashMap3 = new TObjectDoubleHashMap(10, 0.5f, 0.0d);
        int i3 = 0;
        double d = 0.0d;
        double d2 = 0.0d;
        double d3 = 0.0d;
        HashSet hashSet = new HashSet((int) Math.ceil(baseFeatureSet.getNumSpatialFeatures() / 0.75f), 0.75f);
        for (SpatialFeature spatialFeature : baseFeatureSet.spatialFeatures()) {
            hashSet.add(spatialFeature);
        }
        HashSet hashSet2 = new HashSet();
        HashSet hashSet3 = new HashSet();
        FVector[] fVectorArr = new FVector[list.size()];
        FVector[] fVectorArr2 = new FVector[list.size()];
        final float[] fArr = new float[list.size()];
        TDoubleArrayList[] tDoubleArrayListArr = new TDoubleArrayList[baseFeatureSet.getNumSpatialFeatures()];
        TDoubleArrayList[] tDoubleArrayListArr2 = new TDoubleArrayList[baseFeatureSet.getNumSpatialFeatures()];
        for (int i4 = 0; i4 < tDoubleArrayListArr.length; i4++) {
            tDoubleArrayListArr[i4] = new TDoubleArrayList();
            tDoubleArrayListArr2[i4] = new TDoubleArrayList();
        }
        double d4 = 0.0d;
        int i5 = 0;
        for (int i6 = 0; i6 < list.size(); i6++) {
            ExperienceSample experienceSample = list.get(i6);
            FeatureVector[] generateFeatureVectors = experienceSample.generateFeatureVectors(baseFeatureSet);
            FVector computeDistribution = softmaxPolicyLinear.computeDistribution(generateFeatureVectors, experienceSample.gameState().mover());
            FVector computeDistributionErrors = Gradients.computeDistributionErrors(computeDistribution, experienceSample.expertDistribution());
            for (int i7 = 0; i7 < generateFeatureVectors.length; i7++) {
                float f = computeDistributionErrors.get(i7);
                TIntArrayList activeSpatialFeatureIndices = generateFeatureVectors[i7].activeSpatialFeatureIndices();
                activeSpatialFeatureIndices.sort();
                int i8 = 0;
                for (int i9 = 0; i9 < baseFeatureSet.getNumSpatialFeatures(); i9++) {
                    if (i8 >= activeSpatialFeatureIndices.size() || activeSpatialFeatureIndices.getQuick(i8) != i9) {
                        tDoubleArrayListArr2[i9].add(f);
                    } else {
                        tDoubleArrayListArr[i9].add(f);
                        i8++;
                    }
                }
                d4 += (f - d4) / (i5 + 1);
                i5++;
            }
            FVector copy = computeDistributionErrors.copy();
            copy.abs();
            fVectorArr[i6] = computeDistribution;
            fVectorArr2[i6] = computeDistributionErrors;
            fArr[i6] = copy.sum();
        }
        double[] dArr = new double[baseFeatureSet.getNumSpatialFeatures()];
        double[] dArr2 = new double[baseFeatureSet.getNumSpatialFeatures()];
        for (int i10 = 0; i10 < baseFeatureSet.getNumSpatialFeatures(); i10++) {
            TDoubleArrayList tDoubleArrayList2 = tDoubleArrayListArr[i10];
            TDoubleArrayList tDoubleArrayList3 = tDoubleArrayListArr2[i10];
            double size = tDoubleArrayList2.size() / (tDoubleArrayList2.size() + tDoubleArrayListArr2[i10].size());
            double d5 = 0.0d;
            double d6 = 0.0d;
            for (int i11 = 0; i11 < tDoubleArrayList2.size(); i11++) {
                double quick = tDoubleArrayList2.getQuick(i11);
                double d7 = quick - d4;
                d6 += (1.0d - size) * d7;
                d5 += d7 * d7;
                int i12 = i10;
                dArr2[i12] = dArr2[i12] + ((Math.abs(quick) - dArr2[i10]) / (i11 + 1));
            }
            for (int i13 = 0; i13 < tDoubleArrayList3.size(); i13++) {
                double quick2 = tDoubleArrayList3.getQuick(i13) - d4;
                d6 += (0.0d - size) * quick2;
                d5 += quick2 * quick2;
            }
            dArr[i10] = d6 / Math.sqrt(((tDoubleArrayList2.size() * ((1.0d - size) * (1.0d - size))) + (tDoubleArrayList3.size() * ((0.0d - size) * (0.0d - size)))) * d5);
        }
        ArrayList arrayList = new ArrayList(list.size());
        for (int i14 = 0; i14 < list.size(); i14++) {
            arrayList.add(Integer.valueOf(i14));
        }
        Collections.sort(arrayList, new Comparator<Integer>() { // from class: training.feature_discovery.VarianceReductionExpander.1
            @Override // java.util.Comparator
            public int compare(Integer num, Integer num2) {
                float f2 = fArr[num.intValue()] - fArr[num2.intValue()];
                if (f2 > ColumnText.GLOBAL_SPACE_CHAR_RATIO) {
                    return -1;
                }
                return f2 < ColumnText.GLOBAL_SPACE_CHAR_RATIO ? 1 : 0;
            }
        });
        for (int i15 = 0; i15 < arrayList.size(); i15++) {
            int intValue = ((Integer) arrayList.get(i15)).intValue();
            ExperienceSample experienceSample2 = list.get(intValue);
            FVector fVector = fVectorArr2[intValue];
            FastArrayList<Move> moves = experienceSample2.moves();
            for (int i16 = 0; i16 < moves.size(); i16++) {
                i3++;
                HashSet hashSet4 = new HashSet(256, 0.75f);
                ArrayList<FeatureInstance> arrayList2 = new ArrayList(new HashSet(baseFeatureSet.getActiveSpatialFeatureInstances(experienceSample2.gameState(), experienceSample2.lastFromPos(), experienceSample2.lastToPos(), FeatureUtils.fromPos(moves.get(i16)), FeatureUtils.toPos(moves.get(i16)), moves.get(i16).mover())));
                ArrayList arrayList3 = new ArrayList();
                int i17 = 0;
                while (i17 < arrayList2.size()) {
                    FeatureInstance featureInstance = (FeatureInstance) arrayList2.get(i17);
                    FeatureSetExpander.CombinableFeatureInstancePair combinableFeatureInstancePair = new FeatureSetExpander.CombinableFeatureInstancePair(game2, featureInstance, featureInstance);
                    if (hashSet2.contains(combinableFeatureInstancePair)) {
                        arrayList3.add(featureInstance);
                        arrayList2.remove(i17);
                    } else if (hashSet3.contains(combinableFeatureInstancePair)) {
                        arrayList2.remove(i17);
                    } else {
                        i17++;
                    }
                }
                FVector fVector2 = new FVector(arrayList2.size());
                for (int i18 = 0; i18 < arrayList2.size(); i18++) {
                    int spatialFeatureSetIndex = ((FeatureInstance) arrayList2.get(i18)).feature().spatialFeatureSetIndex();
                    fVector2.set(i18, (float) (dArr[spatialFeatureSetIndex] + dArr2[spatialFeatureSetIndex]));
                }
                fVector2.softmax(2.0d);
                for (int min = Math.min(Math.min(Math.max(5, i - (hashSet2.size() / (moves.size() - i16))), i - hashSet2.size()), arrayList2.size()); min > 0; min--) {
                    int sampleFromDistribution = fVector2.sampleFromDistribution();
                    FeatureInstance featureInstance2 = (FeatureInstance) arrayList2.get(sampleFromDistribution);
                    arrayList3.add(featureInstance2);
                    hashSet2.add(new FeatureSetExpander.CombinableFeatureInstancePair(game2, featureInstance2, featureInstance2));
                    fVector2.updateSoftmaxInvalidate(sampleFromDistribution);
                }
                for (FeatureInstance featureInstance3 : arrayList2) {
                    FeatureSetExpander.CombinableFeatureInstancePair combinableFeatureInstancePair2 = new FeatureSetExpander.CombinableFeatureInstancePair(game2, featureInstance3, featureInstance3);
                    if (!hashSet2.contains(combinableFeatureInstancePair2)) {
                        hashSet3.add(combinableFeatureInstancePair2);
                    }
                }
                int size2 = arrayList3.size();
                float f2 = fVector.get(i16);
                d += f2;
                d2 += f2 * f2;
                d3 += (f2 - d3) / i3;
                for (int i19 = 0; i19 < size2; i19++) {
                    FeatureInstance featureInstance4 = (FeatureInstance) arrayList3.get(i19);
                    FeatureSetExpander.CombinableFeatureInstancePair combinableFeatureInstancePair3 = new FeatureSetExpander.CombinableFeatureInstancePair(game2, featureInstance4, featureInstance4);
                    if (hashSet4.add(combinableFeatureInstancePair3)) {
                        tObjectIntHashMap.adjustOrPutValue(combinableFeatureInstancePair3, 1, 1);
                        tObjectDoubleHashMap.adjustOrPutValue(combinableFeatureInstancePair3, f2, f2);
                        tObjectDoubleHashMap2.adjustOrPutValue(combinableFeatureInstancePair3, f2 * f2, f2 * f2);
                        double d8 = (f2 - tObjectDoubleHashMap3.get(combinableFeatureInstancePair3)) / tObjectIntHashMap.get(combinableFeatureInstancePair3);
                        tObjectDoubleHashMap3.adjustOrPutValue(combinableFeatureInstancePair3, d8, d8);
                    }
                    for (int i20 = i19 + 1; i20 < size2; i20++) {
                        FeatureSetExpander.CombinableFeatureInstancePair combinableFeatureInstancePair4 = new FeatureSetExpander.CombinableFeatureInstancePair(game2, featureInstance4, (FeatureInstance) arrayList3.get(i20));
                        if (!hashSet.contains(combinableFeatureInstancePair4.combinedFeature) && hashSet4.add(combinableFeatureInstancePair4)) {
                            tObjectIntHashMap.adjustOrPutValue(combinableFeatureInstancePair4, 1, 1);
                            tObjectDoubleHashMap.adjustOrPutValue(combinableFeatureInstancePair4, f2, f2);
                            tObjectDoubleHashMap2.adjustOrPutValue(combinableFeatureInstancePair4, f2 * f2, f2 * f2);
                            double d9 = (f2 - tObjectDoubleHashMap3.get(combinableFeatureInstancePair4)) / tObjectIntHashMap.get(combinableFeatureInstancePair4);
                            tObjectDoubleHashMap3.adjustOrPutValue(combinableFeatureInstancePair4, d9, d9);
                        }
                    }
                }
            }
        }
        if (d == 0.0d || d2 == 0.0d) {
            return null;
        }
        ArrayList arrayList4 = new ArrayList(tObjectIntHashMap.size());
        double d10 = Double.NEGATIVE_INFINITY;
        int i21 = -1;
        for (FeatureSetExpander.CombinableFeatureInstancePair combinableFeatureInstancePair5 : tObjectIntHashMap.keySet()) {
            if (!combinableFeatureInstancePair5.a.equals(combinableFeatureInstancePair5.b) && (i2 = tObjectIntHashMap.get(combinableFeatureInstancePair5)) != i3 && i2 >= 2) {
                int i22 = tObjectIntHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, combinableFeatureInstancePair5.a, combinableFeatureInstancePair5.a));
                int i23 = tObjectIntHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, combinableFeatureInstancePair5.b, combinableFeatureInstancePair5.b));
                if (i22 != i3 && i23 != i3) {
                    double d11 = tObjectDoubleHashMap.get(combinableFeatureInstancePair5);
                    double d12 = tObjectDoubleHashMap2.get(combinableFeatureInstancePair5);
                    double d13 = tObjectDoubleHashMap3.get(combinableFeatureInstancePair5);
                    double d14 = ((d12 - ((2.0d * d13) * d11)) + ((i2 * d13) * d13)) / (i2 - 1);
                    int i24 = i3 - i2;
                    double d15 = d - d11;
                    double d16 = d2 - d12;
                    double d17 = ((d3 * i3) - (d13 * i2)) / i24;
                    double d18 = (((d2 - ((2.0d * d3) * d)) + ((i3 * d3) * d3)) / (i3 - 1)) - (d14 + (((d16 - ((2.0d * d17) * d15)) + ((i24 * d17) * d17)) / (i2 - 1)));
                    if (Math.max(Math.abs((i2 * (i3 - i22)) / (Math.sqrt(i2 * (i3 - i2)) * Math.sqrt(i22 * (i3 - i22)))), Math.abs((i2 * (i3 - i23)) / (Math.sqrt(i2 * (i3 - i2)) * Math.sqrt(i23 * (i3 - i23))))) != 1.0d && !Double.isNaN(d18)) {
                        arrayList4.add(new FeatureSetExpander.ScoredFeatureInstancePair(combinableFeatureInstancePair5, d18));
                        if (d18 > d10) {
                            d10 = d18;
                            i21 = arrayList4.size() - 1;
                        }
                    }
                }
            }
        }
        while (arrayList4.size() > 0) {
            FeatureSetExpander.ScoredFeatureInstancePair scoredFeatureInstancePair = (FeatureSetExpander.ScoredFeatureInstancePair) arrayList4.remove(i21);
            BaseFeatureSet createExpandedFeatureSet = baseFeatureSet.createExpandedFeatureSet(game2, scoredFeatureInstancePair.pair.combinedFeature);
            if (createExpandedFeatureSet != null) {
                int i25 = tObjectIntHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, scoredFeatureInstancePair.pair.a, scoredFeatureInstancePair.pair.a));
                int i26 = tObjectIntHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, scoredFeatureInstancePair.pair.b, scoredFeatureInstancePair.pair.b));
                int i27 = tObjectIntHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, scoredFeatureInstancePair.pair.a, scoredFeatureInstancePair.pair.b));
                double sqrt = ((i3 * tObjectDoubleHashMap.get(new FeatureSetExpander.CombinableFeatureInstancePair(game2, scoredFeatureInstancePair.pair.a, scoredFeatureInstancePair.pair.b))) - (i27 * d)) / (Math.sqrt((i3 * i27) - (i27 * i27)) * Math.sqrt((i3 * d2) - (d * d)));
                double sqrt2 = ((i3 * i27) - (i27 * i25)) / (Math.sqrt((i3 * i27) - (i27 * i27)) * Math.sqrt((i3 * i25) - (i25 * i25)));
                double sqrt3 = ((i3 * i27) - (i27 * i26)) / (Math.sqrt((i3 * i27) - (i27 * i27)) * Math.sqrt((i3 * i26) - (i26 * i26)));
                double d19 = ((d2 - ((2.0d * d3) * d)) + ((i3 * d3) * d3)) / (i3 - 1);
                double d20 = tObjectDoubleHashMap.get(scoredFeatureInstancePair.pair);
                double d21 = tObjectDoubleHashMap2.get(scoredFeatureInstancePair.pair);
                double d22 = tObjectDoubleHashMap3.get(scoredFeatureInstancePair.pair);
                double d23 = ((d21 - ((2.0d * d22) * d20)) + ((i27 * d22) * d22)) / (i27 - 1);
                int i28 = i3 - i27;
                double d24 = d - d20;
                double d25 = d2 - d21;
                double d26 = ((d3 * i3) - (d22 * i27)) / i28;
                double d27 = ((d25 - ((2.0d * d26) * d24)) + ((i28 * d26) * d26)) / (i27 - 1);
                interruptableExperiment.logLine(printWriter, "New feature added!");
                interruptableExperiment.logLine(printWriter, "new feature = " + createExpandedFeatureSet.spatialFeatures()[createExpandedFeatureSet.getNumSpatialFeatures() - 1]);
                interruptableExperiment.logLine(printWriter, "active feature A = " + scoredFeatureInstancePair.pair.a.feature());
                interruptableExperiment.logLine(printWriter, "rot A = " + scoredFeatureInstancePair.pair.a.rotation());
                interruptableExperiment.logLine(printWriter, "ref A = " + scoredFeatureInstancePair.pair.a.reflection());
                interruptableExperiment.logLine(printWriter, "anchor A = " + scoredFeatureInstancePair.pair.a.anchorSite());
                interruptableExperiment.logLine(printWriter, "active feature B = " + scoredFeatureInstancePair.pair.b.feature());
                interruptableExperiment.logLine(printWriter, "rot B = " + scoredFeatureInstancePair.pair.b.rotation());
                interruptableExperiment.logLine(printWriter, "ref B = " + scoredFeatureInstancePair.pair.b.reflection());
                interruptableExperiment.logLine(printWriter, "anchor B = " + scoredFeatureInstancePair.pair.b.anchorSite());
                interruptableExperiment.logLine(printWriter, "score = " + scoredFeatureInstancePair.score);
                interruptableExperiment.logLine(printWriter, "correlation with errors = " + sqrt);
                interruptableExperiment.logLine(printWriter, "correlation with first constituent = " + sqrt2);
                interruptableExperiment.logLine(printWriter, "correlation with second constituent = " + sqrt3);
                interruptableExperiment.logLine(printWriter, "varComplete = " + d19);
                interruptableExperiment.logLine(printWriter, "varWithPair = " + d23);
                interruptableExperiment.logLine(printWriter, "varWithoutPair = " + d27);
                interruptableExperiment.logLine(printWriter, "varReduction = " + (d19 - (d23 + d27)));
                return createExpandedFeatureSet;
            }
            double d28 = Double.NEGATIVE_INFINITY;
            i21 = -1;
            for (int i29 = 0; i29 < arrayList4.size(); i29++) {
                if (((FeatureSetExpander.ScoredFeatureInstancePair) arrayList4.get(i29)).score > d28) {
                    d28 = ((FeatureSetExpander.ScoredFeatureInstancePair) arrayList4.get(i29)).score;
                    i21 = i29;
                }
            }
        }
        return null;
    }
}
