package supplementary.experiments.feature_importance;

import com.itextpdf.text.pdf.ColumnText;
import com.itextpdf.text.xml.xmp.XmpWriter;
import features.Feature;
import features.FeatureVector;
import features.WeightVector;
import features.feature_sets.BaseFeatureSet;
import function_approx.LinearFunction;
import game.Game;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import main.CommandLineArgParse;
import main.StringRoutines;
import main.collections.ArrayUtils;
import main.collections.FVector;
import org.apache.batik.constants.XMLConstants;
import org.apache.batik.svggen.SVGSyntax;
import other.GameLoader;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import training.expert_iteration.ExItExperience;
import utils.AIFactory;
import utils.ExperimentFileUtils;
import utils.data_structures.experience_buffers.ExperienceBuffer;
import utils.data_structures.experience_buffers.PrioritizedReplayBuffer;
import utils.data_structures.experience_buffers.UniformExperienceBuffer;

/* loaded from: input_file:supplementary/experiments/feature_importance/AnalyseFeatureImportances.class */
public class AnalyseFeatureImportances {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:supplementary/experiments/feature_importance/AnalyseFeatureImportances$Row.class */
    public static class Row {
        public final Feature feature;
        public double sse;
        public double reductionSSE;
        public double sseFalse;
        public double sseTrue;
        public int sampleSizeFalse;
        public int sampleSizeTrue;
        public double meanTargetFalse;
        public double meanTargetTrue;
        public double urgency;
        public double weightedUrgency;
        public double urgencyRatio;
        public double scaledUrgency;

        public Row(Feature feature) {
            this.feature = feature;
        }

        public String toString() {
            return StringRoutines.join(SVGSyntax.COMMA, StringRoutines.quote(this.feature.toString()), Double.valueOf(this.sse), Double.valueOf(this.reductionSSE), Double.valueOf(this.sseFalse), Double.valueOf(this.sseTrue), Double.valueOf(this.sampleSizeFalse), Double.valueOf(this.sampleSizeTrue), Double.valueOf(this.meanTargetFalse), Double.valueOf(this.meanTargetTrue), Double.valueOf(this.urgency), Double.valueOf(this.weightedUrgency), Double.valueOf(this.urgencyRatio), Double.valueOf(this.scaledUrgency));
        }
    }

    private static void analyseFeatureImportances(CommandLineArgParse commandLineArgParse) {
        double d;
        double d2;
        double d3;
        double d4;
        double d5;
        double d6;
        String valueString = commandLineArgParse.getValueString("--game-name");
        String valueString2 = commandLineArgParse.getValueString("--game-training-dir");
        if (!valueString2.endsWith("/")) {
            valueString2 = valueString2 + "/";
        }
        Game loadGameFromName = GameLoader.loadGameFromName(valueString);
        StringBuilder sb = new StringBuilder();
        sb.append("playout=softmax");
        for (int i = 1; i <= loadGameFromName.players().count(); i++) {
            sb.append(",policyweights" + i + XMLConstants.XML_EQUAL_SIGN + ExperimentFileUtils.getLastFilepath(valueString2 + "PolicyWeightsTSPG_P" + i, "txt"));
        }
        sb.append(",boosted=true");
        SoftmaxPolicyLinear softmaxPolicyLinear = (SoftmaxPolicyLinear) ((MCTS) AIFactory.createAI(StringRoutines.join(XMLConstants.XML_CHAR_REF_SUFFIX, "algorithm=MCTS", "selection=noisyag0selection", sb.toString(), "final_move=robustchild", "tree_reuse=true", "learned_selection_policy=playout", "friendly_name=BiasedMCTS"))).playoutStrategy();
        BaseFeatureSet[] featureSets = softmaxPolicyLinear.featureSets();
        LinearFunction[] linearFunctions = softmaxPolicyLinear.linearFunctions();
        softmaxPolicyLinear.initAI(loadGameFromName, -1);
        String lastFilepath = ExperimentFileUtils.getLastFilepath(valueString2 + "ExperienceBuffer_P1", "buf");
        ExperienceBuffer experienceBuffer = null;
        try {
            experienceBuffer = PrioritizedReplayBuffer.fromFile(loadGameFromName, lastFilepath);
        } catch (Exception e) {
            if (experienceBuffer == null) {
                try {
                    experienceBuffer = UniformExperienceBuffer.fromFile(loadGameFromName, lastFilepath);
                } catch (Exception e2) {
                    e.printStackTrace();
                    e2.printStackTrace();
                    return;
                }
            }
        }
        BaseFeatureSet baseFeatureSet = featureSets[1];
        WeightVector effectiveParams = linearFunctions[1].effectiveParams();
        ExItExperience[] allExperience = experienceBuffer.allExperience();
        ArrayList arrayList = new ArrayList();
        TFloatArrayList tFloatArrayList = new TFloatArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (ExItExperience exItExperience : allExperience) {
            if (exItExperience != null && exItExperience.moves().size() > 1) {
                FeatureVector[] generateFeatureVectors = exItExperience.generateFeatureVectors(baseFeatureSet);
                float[] fArr = new float[generateFeatureVectors.length];
                for (int i2 = 0; i2 < generateFeatureVectors.length; i2++) {
                    fArr[i2] = effectiveParams.dot(generateFeatureVectors[i2]);
                }
                float max = ArrayUtils.max(fArr);
                float min = ArrayUtils.min(fArr);
                if (max != min) {
                    for (FeatureVector featureVector : generateFeatureVectors) {
                        arrayList.add(featureVector);
                    }
                    int nextSetBit = exItExperience.winningMoves().nextSetBit(0);
                    while (true) {
                        int i3 = nextSetBit;
                        if (i3 < 0) {
                            break;
                        }
                        fArr[i3] = max;
                        nextSetBit = exItExperience.winningMoves().nextSetBit(i3 + 1);
                    }
                    int nextSetBit2 = exItExperience.losingMoves().nextSetBit(0);
                    while (true) {
                        int i4 = nextSetBit2;
                        if (i4 < 0) {
                            break;
                        }
                        fArr[i4] = min;
                        nextSetBit2 = exItExperience.losingMoves().nextSetBit(i4 + 1);
                    }
                    FVector fVector = new FVector(fArr);
                    fVector.softmax();
                    float max2 = fVector.max();
                    float[] fArr2 = new float[fArr.length];
                    for (int i5 = 0; i5 < fArr2.length; i5++) {
                        fArr2[i5] = fVector.get(i5) / max2;
                    }
                    for (float f : fArr2) {
                        tFloatArrayList.add(f);
                    }
                    arrayList2.add(Arrays.asList(generateFeatureVectors));
                    arrayList3.add(TFloatArrayList.wrap(fArr2));
                }
            }
        }
        int numAspatialFeatures = baseFeatureSet.getNumAspatialFeatures();
        int numSpatialFeatures = baseFeatureSet.getNumSpatialFeatures();
        double[] dArr = new double[numAspatialFeatures];
        int[] iArr = new int[numAspatialFeatures];
        double[] dArr2 = new double[numAspatialFeatures];
        int[] iArr2 = new int[numAspatialFeatures];
        for (int i6 = 0; i6 < numAspatialFeatures; i6++) {
            for (int i7 = 0; i7 < arrayList.size(); i7++) {
                FeatureVector featureVector2 = (FeatureVector) arrayList.get(i7);
                float quick = tFloatArrayList.getQuick(i7);
                if (featureVector2.aspatialFeatureValues().get(i6) != ColumnText.GLOBAL_SPACE_CHAR_RATIO) {
                    int i8 = i6;
                    dArr2[i8] = dArr2[i8] + quick;
                    int i9 = i6;
                    iArr2[i9] = iArr2[i9] + 1;
                } else {
                    int i10 = i6;
                    dArr[i10] = dArr[i10] + quick;
                    int i11 = i6;
                    iArr[i11] = iArr[i11] + 1;
                }
            }
        }
        double[] dArr3 = new double[numSpatialFeatures];
        int[] iArr3 = new int[numSpatialFeatures];
        double[] dArr4 = new double[numSpatialFeatures];
        int[] iArr4 = new int[numSpatialFeatures];
        for (int i12 = 0; i12 < arrayList.size(); i12++) {
            FeatureVector featureVector3 = (FeatureVector) arrayList.get(i12);
            float quick2 = tFloatArrayList.getQuick(i12);
            boolean[] zArr = new boolean[numSpatialFeatures];
            TIntArrayList activeSpatialFeatureIndices = featureVector3.activeSpatialFeatureIndices();
            for (int i13 = 0; i13 < activeSpatialFeatureIndices.size(); i13++) {
                zArr[activeSpatialFeatureIndices.getQuick(i13)] = true;
            }
            for (int i14 = 0; i14 < zArr.length; i14++) {
                if (zArr[i14]) {
                    int i15 = i14;
                    dArr4[i15] = dArr4[i15] + quick2;
                    int i16 = i14;
                    iArr4[i16] = iArr4[i16] + 1;
                } else {
                    int i17 = i14;
                    dArr3[i17] = dArr3[i17] + quick2;
                    int i18 = i14;
                    iArr3[i18] = iArr3[i18] + 1;
                }
            }
        }
        double[] dArr5 = new double[numAspatialFeatures];
        double[] dArr6 = new double[numAspatialFeatures];
        double[] dArr7 = new double[numSpatialFeatures];
        double[] dArr8 = new double[numSpatialFeatures];
        for (int i19 = 0; i19 < numAspatialFeatures; i19++) {
            if (iArr[i19] > 0) {
                dArr5[i19] = dArr[i19] / iArr[i19];
            }
            if (iArr2[i19] > 0) {
                dArr6[i19] = dArr2[i19] / iArr2[i19];
            }
        }
        for (int i20 = 0; i20 < numSpatialFeatures; i20++) {
            if (iArr3[i20] > 0) {
                dArr7[i20] = dArr3[i20] / iArr3[i20];
            }
            if (iArr4[i20] > 0) {
                dArr8[i20] = dArr4[i20] / iArr4[i20];
            }
        }
        ArrayList<Row> arrayList4 = new ArrayList();
        for (int i21 = 0; i21 < numAspatialFeatures; i21++) {
            arrayList4.add(new Row(baseFeatureSet.aspatialFeatures()[i21]));
        }
        for (int i22 = 0; i22 < numSpatialFeatures; i22++) {
            arrayList4.add(new Row(baseFeatureSet.spatialFeatures()[i22]));
        }
        double d7 = 0.0d;
        double sum = tFloatArrayList.sum() / tFloatArrayList.size();
        for (int i23 = 0; i23 < tFloatArrayList.size(); i23++) {
            double quick3 = tFloatArrayList.getQuick(i23) - sum;
            d7 += quick3 * quick3;
        }
        for (int i24 = 0; i24 < numAspatialFeatures; i24++) {
            Row row = (Row) arrayList4.get(i24);
            double d8 = 0.0d;
            double d9 = 0.0d;
            double d10 = 0.0d;
            for (int i25 = 0; i25 < arrayList.size(); i25++) {
                FeatureVector featureVector4 = (FeatureVector) arrayList.get(i25);
                float quick4 = tFloatArrayList.getQuick(i25);
                if (featureVector4.aspatialFeatureValues().get(i24) != ColumnText.GLOBAL_SPACE_CHAR_RATIO) {
                    d6 = quick4 - dArr6[i24];
                    d10 += d6 * d6;
                } else {
                    d6 = quick4 - dArr5[i24];
                    d9 += d6 * d6;
                }
                d8 += d6 * d6;
            }
            row.sse = d8;
            row.reductionSSE = d7 - d8;
            row.sseFalse = d9;
            row.sseTrue = d10;
            row.sampleSizeFalse = iArr[i24];
            row.sampleSizeTrue = iArr2[i24];
            row.meanTargetFalse = dArr5[i24];
            row.meanTargetTrue = dArr6[i24];
        }
        for (int i26 = 0; i26 < numSpatialFeatures; i26++) {
            Row row2 = (Row) arrayList4.get(i26 + numAspatialFeatures);
            double d11 = 0.0d;
            double d12 = 0.0d;
            double d13 = 0.0d;
            for (int i27 = 0; i27 < arrayList.size(); i27++) {
                FeatureVector featureVector5 = (FeatureVector) arrayList.get(i27);
                float quick5 = tFloatArrayList.getQuick(i27);
                if (featureVector5.activeSpatialFeatureIndices().contains(i26)) {
                    d5 = quick5 - dArr8[i26];
                    d13 += d5 * d5;
                } else {
                    d5 = quick5 - dArr7[i26];
                    d12 += d5 * d5;
                }
                d11 += d5 * d5;
            }
            row2.sse = d11;
            row2.reductionSSE = d7 - d11;
            row2.sseFalse = d12;
            row2.sseTrue = d13;
            row2.sampleSizeFalse = iArr3[i26];
            row2.sampleSizeTrue = iArr4[i26];
            row2.meanTargetFalse = dArr7[i26];
            row2.meanTargetTrue = dArr8[i26];
        }
        double d14 = 1.0d - sum;
        for (int i28 = 0; i28 < numAspatialFeatures; i28++) {
            Row row3 = (Row) arrayList4.get(i28);
            double max3 = Math.max(dArr5[i28] / sum, sum / dArr5[i28]);
            double max4 = Math.max(dArr6[i28] / sum, sum / dArr6[i28]);
            row3.urgency = Math.max(max3, max4);
            row3.weightedUrgency = Math.max(Math.log10(iArr[i28]) * Math.max(dArr5[i28] / sum, sum / dArr5[i28]), Math.log10(iArr2[i28]) * Math.max(dArr6[i28] / sum, sum / dArr6[i28]));
            row3.urgencyRatio = Math.max(max3 / max4, max4 / max3);
            double d15 = dArr5[i28] > sum ? (dArr5[i28] - sum) / d14 : (sum - dArr5[i28]) / sum;
            if (dArr6[i28] > sum) {
                d3 = dArr6[i28] - sum;
                d4 = d14;
            } else {
                d3 = sum - dArr6[i28];
                d4 = sum;
            }
            row3.scaledUrgency = Math.max(d15, d3 / d4);
        }
        for (int i29 = 0; i29 < numSpatialFeatures; i29++) {
            Row row4 = (Row) arrayList4.get(i29 + numAspatialFeatures);
            double max5 = Math.max(dArr7[i29] / sum, sum / dArr7[i29]);
            double max6 = Math.max(dArr8[i29] / sum, sum / dArr8[i29]);
            row4.urgency = Math.max(max5, max6);
            row4.weightedUrgency = Math.max(Math.log10(iArr3[i29]) * Math.max(dArr7[i29] / sum, sum / dArr7[i29]), Math.log10(iArr4[i29]) * Math.max(dArr8[i29] / sum, sum / dArr8[i29]));
            row4.urgencyRatio = Math.max(max5 / max6, max6 / max5);
            double d16 = dArr7[i29] > sum ? (dArr7[i29] - sum) / d14 : (sum - dArr7[i29]) / sum;
            if (dArr8[i29] > sum) {
                d = dArr8[i29] - sum;
                d2 = d14;
            } else {
                d = sum - dArr8[i29];
                d2 = sum;
            }
            row4.scaledUrgency = Math.max(d16, d / d2);
        }
        Collections.sort(arrayList4, new Comparator<Row>() { // from class: supplementary.experiments.feature_importance.AnalyseFeatureImportances.1
            @Override // java.util.Comparator
            public int compare(Row row5, Row row6) {
                if (row5.reductionSSE > row6.reductionSSE) {
                    return -1;
                }
                return row5.reductionSSE < row6.reductionSSE ? 1 : 0;
            }
        });
        try {
            PrintWriter printWriter = new PrintWriter(commandLineArgParse.getValueString("--out-file"), XmpWriter.UTF8);
            try {
                printWriter.println("Feature,SSE,ReductionSSE,SseFalse,SseTrue,SampleSizeFalse,SampleSizeTrue,MeanTargetFalse,MeanTargetTrue,Urgency,WeightedUrgency,UrgencyRatio,ScaledUrgency");
                for (Row row5 : arrayList4) {
                    if (row5.sampleSizeFalse > 0 && row5.sampleSizeTrue > 0) {
                        printWriter.println(row5);
                    }
                }
                printWriter.close();
            } finally {
            }
        } catch (IOException e3) {
            e3.printStackTrace();
        }
    }

    public static void main(String[] strArr) {
        CommandLineArgParse commandLineArgParse = new CommandLineArgParse(true, "Analyses feature importances for a game.");
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game-training-dir").help("The directory with training outcomes for the game to analyse.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game-name").help("Name of the game.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--out-file").help("Filepath to write data to.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        if (commandLineArgParse.parseArguments(strArr)) {
            analyseFeatureImportances(commandLineArgParse);
        }
    }
}
