package supplementary.experiments.feature_importance;

import com.itextpdf.text.pdf.ColumnText;
import com.itextpdf.text.xml.xmp.XmpWriter;
import decision_trees.classifiers.DecisionConditionNode;
import decision_trees.classifiers.DecisionTreeNode;
import decision_trees.classifiers.ExperienceUrgencyTreeLearner;
import features.Feature;
import features.FeatureVector;
import features.WeightVector;
import features.aspatial.AspatialFeature;
import features.feature_sets.BaseFeatureSet;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import features.spatial.SpatialFeature;
import function_approx.LinearFunction;
import game.Game;
import game.types.play.RoleType;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import main.CommandLineArgParse;
import main.StringRoutines;
import main.collections.ArrayUtils;
import main.collections.FVector;
import main.collections.ListUtils;
import main.collections.ScoredInt;
import main.math.statistics.IncrementalStats;
import metadata.ai.features.FeatureSet;
import metadata.ai.features.Features;
import metadata.ai.misc.Pair;
import org.apache.batik.constants.XMLConstants;
import other.GameLoader;
import other.RankUtils;
import other.context.Context;
import other.trial.Trial;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import training.expert_iteration.ExItExperience;
import utils.AIFactory;
import utils.ExperimentFileUtils;
import utils.data_structures.experience_buffers.ExperienceBuffer;
import utils.data_structures.experience_buffers.PrioritizedReplayBuffer;
import utils.data_structures.experience_buffers.UniformExperienceBuffer;

/**
 * Experiment that selects a small set of "top" features per player for a game, using the
 * artefacts of an earlier training run (policy weights and experience buffers).
 *
 * <p>Outline of what {@link #identifyTopFeatures(CommandLineArgParse)} does:
 * <ol>
 *   <li>Loads the latest {@code PolicyWeightsPlayout_P*} and {@code PolicyWeightsTSPG_P*} weight
 *       files and {@code ExperienceBuffer_P*} files from {@code --training-out-dir}.</li>
 *   <li>Per player, builds two trees with {@link ExperienceUrgencyTreeLearner} (one per set of
 *       weights), collects the spatial features they split on, and adds generalisers of those
 *       features as further candidates.</li>
 *   <li>Estimates a weight for every candidate feature from the experience buffer.</li>
 *   <li>Repeatedly plays single-feature softmax policies against each other, keeping the better
 *       half of each player's candidates until none has more than {@link #GOAL_NUM_FEATURES},
 *       and writes the ranking to {@code RankedFeatures.txt} in {@code --out-dir}.</li>
 *   <li>Greedily grows a feature set per player from that ranking, keeping a feature only if it
 *       raises that player's mean utility, and writes the result to {@code BestFeatures.txt}.</li>
 * </ol>
 *
 * <p>All per-player arrays and lists in this class are indexed by player ID; index 0 is unused.
 */
public class IdentifyTopFeatures {

    /** Trials per candidate feature in elimination round 1; round {@code r} uses {@code r} times as many. */
    private static final int NUM_TRIALS_PER_FEATURE_EVAL = 50;

    /**
     * Elimination stops once no player has more candidates than this; the greedily grown feature
     * set of each player also contains at most this many features.
     */
    private static final int GOAL_NUM_FEATURES = 15;

    /** Trials per evaluation (baseline or extended feature set) while growing feature sets. */
    private static final int NUM_EVAL_TRIALS_FEATURESET_GROWING = 150;

    /**
     * Runs the whole pipeline for the game selected by {@code --game} / {@code --ruleset}.
     * Exceptions raised after the game has been loaded are printed to stderr instead of being
     * propagated.
     *
     * @param argParse parsed command-line arguments
     * @throws IllegalArgumentException if the game cannot be loaded
     */
    private static void identifyTopFeatures(final CommandLineArgParse argParse) {
        final String gameName = argParse.getValueString("--game");
        final String rulesetName = argParse.getValueString("--ruleset");
        final Game game = (rulesetName == null || rulesetName.equals(""))
                ? GameLoader.loadGameFromName(gameName)
                : GameLoader.loadGameFromName(gameName, rulesetName);
        if (game == null) {
            throw new IllegalArgumentException("Cannot load game: " + gameName + " (ruleset = " + rulesetName + ")");
        }

        try {
            final int numPlayers = game.players().count();
            final String trainingOutDir = withTrailingSlash(argParse.getValueString("--training-out-dir"));

            final BaseFeatureSet[] trainedFeatureSets = new BaseFeatureSet[numPlayers + 1];
            final LinearFunction[] playoutFunctions = new LinearFunction[numPlayers + 1];
            final LinearFunction[] tspgFunctions = new LinearFunction[numPlayers + 1];
            if (!loadFeaturesAndWeights(game, trainingOutDir, trainedFeatureSets, playoutFunctions, tspgFunctions)) {
                System.out.println("Did not manage to load any files for " + gameName + " (ruleset = " + rulesetName + ")");
                return;
            }

            final ExperienceBuffer[] experienceBuffers = new ExperienceBuffer[numPlayers + 1];
            loadExperienceBuffers(game, trainingOutDir, experienceBuffers);
            for (int p = 1; p <= numPlayers; ++p) {
                // Tree learning and weight estimation below both need a buffer for every player
                if (experienceBuffers[p] == null) {
                    System.out.println("Did not manage to load experience buffer for Player " + p
                            + " of " + gameName + " (ruleset = " + rulesetName + ")");
                    return;
                }
            }

            final List<List<AspatialFeature>> candidateAspatial = new ArrayList<>();
            final List<List<SpatialFeature>> candidateSpatial = new ArrayList<>();
            candidateAspatial.add(null);
            candidateSpatial.add(null);
            for (int p = 1; p <= numPlayers; ++p) {
                // NOTE(review): 4 and 10 are ExperienceUrgencyTreeLearner hyperparameters; see that class.
                final DecisionTreeNode playoutTree = ExperienceUrgencyTreeLearner.buildTree(
                        trainedFeatureSets[p], playoutFunctions[p], experienceBuffers[p], 4, 10);
                final DecisionTreeNode tspgTree = ExperienceUrgencyTreeLearner.buildTree(
                        trainedFeatureSets[p], tspgFunctions[p], experienceBuffers[p], 4, 10);

                final List<Feature> treeFeatures = new ArrayList<>();
                collectFeatures(playoutTree, treeFeatures);
                collectFeatures(tspgTree, treeFeatures);

                final List<SpatialFeature> spatialTreeFeatures = new ArrayList<>();
                for (final Feature feature : treeFeatures) {
                    if (!(feature instanceof AspatialFeature)) {
                        spatialTreeFeatures.add((SpatialFeature) feature);
                    }
                }

                final List<SpatialFeature> spatialCandidates = SpatialFeature.simplifySpatialFeaturesList(
                        game, SpatialFeature.deduplicate(spatialTreeFeatures));
                final HashSet<SpatialFeature.RotRefInvariantFeature> generalisers = new HashSet<>();
                for (final SpatialFeature feature : spatialCandidates) {
                    feature.generateGeneralisers(game, generalisers, 1);
                }
                for (final SpatialFeature.RotRefInvariantFeature generaliser : generalisers) {
                    spatialCandidates.add(generaliser.feature());
                }

                // NOTE(review): aspatial features the trees split on are not used as candidates
                // (the aspatial candidate list is always empty) — confirm this is intended.
                candidateAspatial.add(new ArrayList<AspatialFeature>());
                candidateSpatial.add(SpatialFeature.simplifySpatialFeaturesList(
                        game, SpatialFeature.deduplicate(spatialCandidates)));
            }

            final FVector[] candidateWeights = new FVector[numPlayers + 1];
            for (int p = 1; p <= numPlayers; ++p) {
                final BaseFeatureSet candidateFeatureSet =
                        JITSPatterNetFeatureSet.construct(candidateAspatial.get(p), candidateSpatial.get(p));
                candidateFeatureSet.init(game, new int[] {p}, null);
                candidateWeights[p] = computeCandidateFeatureWeights(
                        candidateFeatureSet, trainedFeatureSets[p], playoutFunctions[p], experienceBuffers[p]);
            }

            // Release references to training artefacts that the evaluation phase no longer needs
            Arrays.fill(experienceBuffers, null);
            Arrays.fill(trainedFeatureSets, null);

            // Aspatial first, then spatial: the same order in which candidateWeights is indexed
            final List<List<Feature>> candidateFeatures = new ArrayList<>(numPlayers + 1);
            candidateFeatures.add(null);
            for (int p = 1; p <= numPlayers; ++p) {
                final List<Feature> combined = new ArrayList<>();
                combined.addAll(candidateAspatial.get(p));
                combined.addAll(candidateSpatial.get(p));
                candidateFeatures.add(combined);
            }

            evaluateCandidateFeatures(game, candidateFeatures, candidateWeights, argParse);
        } catch (final Exception e) {
            System.err.println("Exception in game: " + gameName + " (ruleset = " + rulesetName + ")");
            e.printStackTrace();
        }
    }

    /**
     * Ranks the candidate features by playing strength and greedily builds a final feature set
     * per player. Writes {@code RankedFeatures.txt} and {@code BestFeatures.txt} to
     * {@code --out-dir} (created if missing).
     *
     * @param game game to play
     * @param candidateFeatures per player, the candidate features
     * @param candidateWeights per player, one weight per candidate (same indexing as {@code candidateFeatures})
     * @param argParse parsed command-line arguments
     * @throws IllegalArgumentException if some player has no candidate features
     */
    private static void evaluateCandidateFeatures(final Game game, final List<List<Feature>> candidateFeatures,
            final FVector[] candidateWeights, final CommandLineArgParse argParse) {
        final String outDir = withTrailingSlash(argParse.getValueString("--out-dir"));
        new File(outDir).mkdirs();

        final int numPlayers = game.players().count();
        for (int p = 1; p <= numPlayers; ++p) {
            if (candidateFeatures.get(p).isEmpty()) {
                throw new IllegalArgumentException("No candidate features for Player " + p);
            }
        }

        final Context context = new Context(game, new Trial(game));

        final IncrementalStats[][] featureScores = new IncrementalStats[numPlayers + 1][];
        for (int p = 1; p <= numPlayers; ++p) {
            featureScores[p] = newStatsArray(candidateFeatures.get(p).size());
        }

        final TIntArrayList[] remaining =
                eliminateCandidates(game, context, candidateFeatures, candidateWeights, featureScores);

        final List<List<ScoredInt>> rankings = new ArrayList<>(numPlayers + 1);
        rankings.add(null);
        for (int p = 1; p <= numPlayers; ++p) {
            rankings.add(rankByMeanScore(remaining[p], featureScores[p]));
        }
        writeRankedFeatures(outDir + "RankedFeatures.txt", candidateFeatures, candidateWeights, rankings, featureScores);

        final BaseFeatureSet[] bestFeatureSets = new BaseFeatureSet[numPlayers + 1];
        final LinearFunction[] bestFunctions = new LinearFunction[numPlayers + 1];
        growFeatureSets(game, context, candidateFeatures, candidateWeights, rankings, bestFeatureSets, bestFunctions);
        writeBestFeatures(outDir + "BestFeatures.txt", bestFeatureSets, bestFunctions);
    }

    /**
     * Repeatedly evaluates every remaining candidate of every player as a single-feature softmax
     * policy, against opponents that each use a random remaining candidate, and keeps the better
     * half (but never fewer than {@link #GOAL_NUM_FEATURES}) of each player's candidates. Stops
     * after the first round in which no player keeps more than {@link #GOAL_NUM_FEATURES}.
     *
     * @param featureScores per player, per candidate, accumulated utilities (updated in place,
     *                      cumulatively across rounds)
     * @return per player, the indices of the candidates that survived
     */
    private static TIntArrayList[] eliminateCandidates(final Game game, final Context context,
            final List<List<Feature>> candidateFeatures, final FVector[] candidateWeights,
            final IncrementalStats[][] featureScores) {
        final int numPlayers = game.players().count();
        final TIntArrayList[] remaining = new TIntArrayList[numPlayers + 1];
        final BaseFeatureSet[] featureSets = new BaseFeatureSet[numPlayers + 1];
        final LinearFunction[] linearFunctions = new LinearFunction[numPlayers + 1];
        for (int p = 1; p <= numPlayers; ++p) {
            remaining[p] = ListUtils.range(candidateFeatures.get(p).size());
            // Single weight; overwritten with the current candidate's weight before every use
            linearFunctions[p] = new LinearFunction(new WeightVector(FVector.wrap(new float[] {Float.NaN})));
        }

        boolean done = false;
        int round = 0;
        while (!done) {
            ++round;
            final int trialsPerFeature = NUM_TRIALS_PER_FEATURE_EVAL * round;

            for (int p = 1; p <= numPlayers; ++p) {
                for (int i = 0; i < remaining[p].size(); ++i) {
                    final int featureIdx = remaining[p].getQuick(i);
                    featureSets[p] = singleFeatureSet(candidateFeatures.get(p).get(featureIdx));
                    featureSets[p].init(game, new int[] {p}, null);
                    linearFunctions[p].trainableParams().allWeights().set(0, candidateWeights[p].get(featureIdx));

                    for (int t = 0; t < trialsPerFeature; ++t) {
                        final int[] featureIdxPerPlayer = new int[numPlayers + 1];
                        featureIdxPerPlayer[p] = featureIdx;
                        for (int other = 1; other <= numPlayers; ++other) {
                            if (other == p) {
                                continue;
                            }
                            // Opponents play a random candidate that is still in the running
                            final TIntArrayList otherRemaining = remaining[other];
                            final int otherIdx = otherRemaining.getQuick(
                                    ThreadLocalRandom.current().nextInt(otherRemaining.size()));
                            featureIdxPerPlayer[other] = otherIdx;
                            featureSets[other] = singleFeatureSet(candidateFeatures.get(other).get(otherIdx));
                            linearFunctions[other].trainableParams().allWeights().set(0, candidateWeights[other].get(otherIdx));
                        }

                        final List<SoftmaxPolicyLinear> policies = new ArrayList<>(numPlayers + 1);
                        policies.add(null);
                        for (int q = 1; q <= numPlayers; ++q) {
                            policies.add(new SoftmaxPolicyLinear(linearFunctions, featureSets));
                        }

                        final double[] utilities = playTrial(game, context, policies);
                        for (int q = 1; q <= numPlayers; ++q) {
                            featureScores[q][featureIdxPerPlayer[q]].observe(utilities[q]);
                        }
                    }
                }
            }

            done = true;
            for (int p = 1; p <= numPlayers; ++p) {
                final List<ScoredInt> ranked = rankByMeanScore(remaining[p], featureScores[p]);
                final int numToKeep = Math.min(ranked.size(), Math.max(GOAL_NUM_FEATURES, ranked.size() / 2));
                final TIntArrayList kept = new TIntArrayList();
                for (int i = 0; i < numToKeep; ++i) {
                    kept.add(ranked.get(i).object());
                }
                remaining[p] = kept;
                if (numToKeep > GOAL_NUM_FEATURES) {
                    done = false;
                }
            }
        }
        return remaining;
    }

    /**
     * Greedily grows one feature set per player: starts from the top-ranked feature and, for
     * ranks 1 .. {@link #GOAL_NUM_FEATURES}-1, adds the feature at that rank if the extended
     * set achieves a higher mean utility than the current sets (all players' current sets are
     * evaluated first as the baseline).
     *
     * @param featureSets     output: per player, the grown feature set
     * @param linearFunctions output: per player, weights aligned with {@code featureSets}
     */
    private static void growFeatureSets(final Game game, final Context context,
            final List<List<Feature>> candidateFeatures, final FVector[] candidateWeights,
            final List<List<ScoredInt>> rankings, final BaseFeatureSet[] featureSets,
            final LinearFunction[] linearFunctions) {
        final int numPlayers = game.players().count();

        for (int p = 1; p <= numPlayers; ++p) {
            final int topIdx = rankings.get(p).get(0).object();
            featureSets[p] = singleFeatureSet(candidateFeatures.get(p).get(topIdx));
            linearFunctions[p] = new LinearFunction(new WeightVector(FVector.wrap(new float[] {candidateWeights[p].get(topIdx)})));
        }

        for (int rank = 1; rank < GOAL_NUM_FEATURES; ++rank) {
            // Baseline: every player uses its current feature set
            final IncrementalStats[] baselineScores = newStatsArray(numPlayers + 1);
            final List<SoftmaxPolicyLinear> baselinePolicies = new ArrayList<>(numPlayers + 1);
            baselinePolicies.add(null);
            for (int p = 1; p <= numPlayers; ++p) {
                baselinePolicies.add(new SoftmaxPolicyLinear(linearFunctions, featureSets));
            }
            for (int t = 0; t < NUM_EVAL_TRIALS_FEATURESET_GROWING; ++t) {
                final double[] utilities = playTrial(game, context, baselinePolicies);
                for (int p = 1; p <= numPlayers; ++p) {
                    baselineScores[p].observe(utilities[p]);
                }
            }

            final boolean[] accept = new boolean[numPlayers + 1];
            for (int p = 1; p <= numPlayers; ++p) {
                if (rank >= rankings.get(p).size()) {
                    continue;
                }
                final int featureIdx = rankings.get(p).get(rank).object();
                final List<Feature> extendedFeatures = allFeatures(featureSets[p]);
                extendedFeatures.add(candidateFeatures.get(p).get(featureIdx));

                final BaseFeatureSet[] trialFeatureSets = Arrays.copyOf(featureSets, numPlayers + 1);
                final LinearFunction[] trialFunctions = Arrays.copyOf(linearFunctions, numPlayers + 1);
                trialFeatureSets[p] = JITSPatterNetFeatureSet.construct(extendedFeatures);
                trialFunctions[p] = extendedFunction(linearFunctions[p], extendedFeatures.size(), candidateWeights[p].get(featureIdx));

                final List<SoftmaxPolicyLinear> policies = new ArrayList<>(numPlayers + 1);
                policies.add(null);
                for (int q = 1; q <= numPlayers; ++q) {
                    policies.add(q == p ? new SoftmaxPolicyLinear(trialFunctions, trialFeatureSets) : baselinePolicies.get(q));
                }

                final IncrementalStats candidateScore = new IncrementalStats();
                for (int t = 0; t < NUM_EVAL_TRIALS_FEATURESET_GROWING; ++t) {
                    candidateScore.observe(playTrial(game, context, policies)[p]);
                }
                if (candidateScore.getMean() > baselineScores[p].getMean()) {
                    accept[p] = true;
                }
            }

            // Apply accepted extensions only after all players were tested against the same baseline
            for (int p = 1; p <= numPlayers; ++p) {
                if (!accept[p]) {
                    continue;
                }
                final int featureIdx = rankings.get(p).get(rank).object();
                final List<Feature> extendedFeatures = allFeatures(featureSets[p]);
                extendedFeatures.add(candidateFeatures.get(p).get(featureIdx));
                final LinearFunction extended = extendedFunction(linearFunctions[p], extendedFeatures.size(), candidateWeights[p].get(featureIdx));
                featureSets[p] = JITSPatterNetFeatureSet.construct(extendedFeatures);
                linearFunctions[p] = extended;
            }
        }
    }

    /**
     * Plays one trial from the initial state: initialises every policy for its player, plays
     * out, closes every policy again.
     *
     * @param policies per player, the policy to use (index 0 unused)
     * @return agent utilities of the finished trial, indexed by player ID
     */
    private static double[] playTrial(final Game game, final Context context, final List<SoftmaxPolicyLinear> policies) {
        game.start(context);
        for (int p = 1; p < policies.size(); ++p) {
            policies.get(p).initAI(game, p);
        }
        // Fresh list whose element type is inferred from Game.playout()'s parameter type
        game.playout(context, new ArrayList<>(policies), 1.0d, null, -1, -1, null);
        for (int p = 1; p < policies.size(); ++p) {
            policies.get(p).closeAI();
        }
        return RankUtils.agentUtilities(context);
    }

    /** Returns the given candidate indices sorted by descending mean score. */
    private static List<ScoredInt> rankByMeanScore(final TIntArrayList featureIndices, final IncrementalStats[] scores) {
        final List<ScoredInt> ranked = new ArrayList<>(featureIndices.size());
        for (int i = 0; i < featureIndices.size(); ++i) {
            final int idx = featureIndices.getQuick(i);
            ranked.add(new ScoredInt(idx, scores[idx].getMean()));
        }
        Collections.sort(ranked, ScoredInt.DESCENDING);
        return ranked;
    }

    /** Returns an array of {@code n} fresh, empty stats objects. */
    private static IncrementalStats[] newStatsArray(final int n) {
        final IncrementalStats[] stats = new IncrementalStats[n];
        for (int i = 0; i < n; ++i) {
            stats[i] = new IncrementalStats();
        }
        return stats;
    }

    /** Builds a (not yet initialised) feature set containing only the given feature. */
    private static BaseFeatureSet singleFeatureSet(final Feature feature) {
        return JITSPatterNetFeatureSet.construct(Arrays.asList(feature));
    }

    /** Returns a mutable list of the set's features: aspatial ones first, then spatial ones. */
    private static List<Feature> allFeatures(final BaseFeatureSet featureSet) {
        final List<Feature> features = new ArrayList<>();
        for (final AspatialFeature feature : featureSet.aspatialFeatures()) {
            features.add(feature);
        }
        for (final SpatialFeature feature : featureSet.spatialFeatures()) {
            features.add(feature);
        }
        return features;
    }

    /**
     * Returns a new function with {@code numWeights} weights: the first {@code numWeights - 1}
     * copied from {@code base}, the last set to {@code extraWeight}.
     */
    private static LinearFunction extendedFunction(final LinearFunction base, final int numWeights, final float extraWeight) {
        final LinearFunction extended = new LinearFunction(new WeightVector(new FVector(numWeights)));
        final FVector weights = extended.trainableParams().allWeights();
        weights.copyFrom(base.trainableParams().allWeights(), 0, 0, numWeights - 1);
        weights.set(numWeights - 1, extraWeight);
        return extended;
    }

    /** Writes, per player, every ranked candidate with its weight and mean score (UTF-8). */
    private static void writeRankedFeatures(final String filepath, final List<List<Feature>> candidateFeatures,
            final FVector[] candidateWeights, final List<List<ScoredInt>> rankings, final IncrementalStats[][] featureScores) {
        try (PrintWriter writer = new PrintWriter(filepath, "UTF-8")) {
            for (int p = 1; p < rankings.size(); ++p) {
                writer.println("Scores for Player " + p);
                for (final ScoredInt scored : rankings.get(p)) {
                    final int idx = scored.object();
                    writer.println("Feature=" + candidateFeatures.get(p).get(idx)
                            + ",weight=" + candidateWeights[p].get(idx)
                            + ",score=" + featureScores[p][idx].getMean());
                }
                writer.println();
            }
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes the grown feature sets as a {@link Features} metadata item (UTF-8). Features are
     * listed aspatial first, then spatial, which is the order their weights are stored in.
     */
    private static void writeBestFeatures(final String filepath, final BaseFeatureSet[] featureSets, final LinearFunction[] linearFunctions) {
        final int numPlayers = featureSets.length - 1;
        final FeatureSet[] metadataFeatureSets = new FeatureSet[numPlayers];
        for (int p = 1; p <= numPlayers; ++p) {
            final List<Feature> features = allFeatures(featureSets[p]);
            final FVector weights = linearFunctions[p].trainableParams().allWeights();
            final Pair[] pairs = new Pair[features.size()];
            for (int i = 0; i < pairs.length; ++i) {
                pairs[i] = new Pair(features.get(i).toString(), Float.valueOf(weights.get(i)));
            }
            metadataFeatureSets[p - 1] = new FeatureSet(RoleType.roleForPlayerId(p), pairs);
        }

        try (PrintWriter writer = new PrintWriter(filepath, "UTF-8")) {
            writer.print(new Features(metadataFeatureSets).toString());
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Estimates one weight per candidate feature from an experience buffer.
     *
     * <p>For every experience with more than one move, whose trained scores are not all equal,
     * each move gets a target value:
     * <ul>
     *   <li>by default, the trained linear score (dot product of {@code linearFunction}'s
     *       effective parameters with the move's features from {@code baseFeatureSet2});</li>
     *   <li>the maximum such score for winning moves;</li>
     *   <li>the minimum such score for losing moves.</li>
     * </ul>
     * A candidate's weight is the mean target over moves where it is active minus the mean
     * target over all moves. NOTE(review): a candidate that is never active gets
     * {@code mean(empty) - mean(all)}; assumes an empty {@link IncrementalStats} reports mean 0.
     *
     * @param baseFeatureSet   candidate features whose weights are estimated
     * @param baseFeatureSet2  trained feature set matching {@code linearFunction}
     * @param linearFunction   trained weights
     * @param experienceBuffer source of experience
     * @return weights indexed like {@code baseFeatureSet} (aspatial first, then spatial)
     */
    private static FVector computeCandidateFeatureWeights(final BaseFeatureSet baseFeatureSet, final BaseFeatureSet baseFeatureSet2,
            final LinearFunction linearFunction, final ExperienceBuffer experienceBuffer) {
        final WeightVector trainedParams = linearFunction.effectiveParams();
        final List<FeatureVector> candidateVectors = new ArrayList<>();
        final TFloatArrayList targets = new TFloatArrayList();

        for (final ExItExperience experience : experienceBuffer.allExperience()) {
            if (experience == null || experience.moves().size() <= 1) {
                continue;
            }
            final FeatureVector[] trainedVectors = experience.generateFeatureVectors(baseFeatureSet2);
            final FeatureVector[] expCandidateVectors = experience.generateFeatureVectors(baseFeatureSet);

            final float[] scores = new float[trainedVectors.length];
            for (int i = 0; i < trainedVectors.length; ++i) {
                scores[i] = trainedParams.dot(trainedVectors[i]);
            }
            final float maxScore = ArrayUtils.max(scores);
            final float minScore = ArrayUtils.min(scores);
            if (maxScore == minScore) {
                continue;    // all moves scored equally: experience skipped
            }

            for (int i = experience.winningMoves().nextSetBit(0); i >= 0; i = experience.winningMoves().nextSetBit(i + 1)) {
                scores[i] = maxScore;
            }
            for (int i = experience.losingMoves().nextSetBit(0); i >= 0; i = experience.losingMoves().nextSetBit(i + 1)) {
                scores[i] = minScore;
            }

            for (int i = 0; i < expCandidateVectors.length; ++i) {
                candidateVectors.add(expCandidateVectors[i]);
                targets.add(scores[i]);
            }
        }

        final int numCandidates = baseFeatureSet.getNumFeatures();
        final int numAspatial = baseFeatureSet.getNumAspatialFeatures();
        final IncrementalStats overall = new IncrementalStats();
        final IncrementalStats[] whenActive = newStatsArray(numCandidates);

        for (int s = 0; s < candidateVectors.size(); ++s) {
            final FeatureVector featureVector = candidateVectors.get(s);
            final float target = targets.getQuick(s);
            overall.observe(target);

            final FVector aspatialValues = featureVector.aspatialFeatureValues();
            for (int a = 0; a < aspatialValues.dim(); ++a) {
                if (aspatialValues.get(a) != 0.f) {
                    whenActive[a].observe(target);
                }
            }

            final TIntArrayList activeSpatial = featureVector.activeSpatialFeatureIndices();
            for (int k = 0; k < activeSpatial.size(); ++k) {
                // Spatial features are stored after all aspatial ones
                whenActive[numAspatial + activeSpatial.getQuick(k)].observe(target);
            }
        }

        final FVector weights = new FVector(numCandidates);
        for (int i = 0; i < numCandidates; ++i) {
            weights.set(i, (float) (whenActive[i].getMean() - overall.getMean()));
        }
        return weights;
    }

    /**
     * Appends, in pre-order, the feature of every condition node in the tree rooted at
     * {@code decisionTreeNode}. Non-condition nodes contribute nothing. Duplicates are kept.
     */
    private static void collectFeatures(final DecisionTreeNode decisionTreeNode, final List<Feature> list) {
        if (decisionTreeNode instanceof DecisionConditionNode) {
            final DecisionConditionNode condition = (DecisionConditionNode) decisionTreeNode;
            list.add(condition.feature());
            collectFeatures(condition.trueNode(), list);
            collectFeatures(condition.falseNode(), list);
        }
    }

    /**
     * Loads the latest playout-policy weights ({@code PolicyWeightsPlayout_P<p>}) and boosted
     * TSPG weights ({@code PolicyWeightsTSPG_P<p>}) from {@code str}, by building MCTS agents
     * through {@link AIFactory} and taking their softmax playout policies.
     *
     * @param str                 directory (with trailing slash) containing the weight files
     * @param baseFeatureSetArr   output: feature sets of the playout policy
     * @param linearFunctionArr   output: playout-policy weights
     * @param linearFunctionArr2  output: TSPG weights
     * @return {@code false} if a weights file is missing or loading throws (the exception is
     *         printed to stdout); {@code true} otherwise
     */
    private static boolean loadFeaturesAndWeights(final Game game2, final String str, final BaseFeatureSet[] baseFeatureSetArr,
            final LinearFunction[] linearFunctionArr, final LinearFunction[] linearFunctionArr2) {
        try {
            final SoftmaxPolicyLinear playoutPolicy = createSoftmaxPlayoutPolicy(game2, str, "PolicyWeightsPlayout_P", false);
            if (playoutPolicy == null) {
                return false;
            }
            final BaseFeatureSet[] featureSets = playoutPolicy.featureSets();
            final LinearFunction[] playoutFunctions = playoutPolicy.linearFunctions();
            playoutPolicy.initAI(game2, -1);
            System.arraycopy(featureSets, 0, baseFeatureSetArr, 0, featureSets.length);
            System.arraycopy(playoutFunctions, 0, linearFunctionArr, 0, playoutFunctions.length);

            final SoftmaxPolicyLinear tspgPolicy = createSoftmaxPlayoutPolicy(game2, str, "PolicyWeightsTSPG_P", true);
            if (tspgPolicy == null) {
                return false;
            }
            final LinearFunction[] tspgFunctions = tspgPolicy.linearFunctions();
            tspgPolicy.initAI(game2, -1);
            System.arraycopy(tspgFunctions, 0, linearFunctionArr2, 0, tspgFunctions.length);
            return true;
        } catch (final Exception e) {
            e.printStackTrace(System.out);
            return false;
        }
    }

    /**
     * Builds a BiasedMCTS agent whose playout uses the latest {@code <dir><filePrefix><p>*.txt}
     * weights of every player, and returns its (not yet initialised) playout policy.
     *
     * @param boosted whether to append {@code ,boosted=true} to the playout specification
     * @return the playout policy, or {@code null} if some player's weights file is missing
     */
    private static SoftmaxPolicyLinear createSoftmaxPlayoutPolicy(final Game game, final String dir,
            final String filePrefix, final boolean boosted) {
        final StringBuilder playoutSpec = new StringBuilder("playout=softmax");
        for (int p = 1; p <= game.players().count(); ++p) {
            final String weightsPath = ExperimentFileUtils.getLastFilepath(dir + filePrefix + p, "txt");
            if (weightsPath == null) {
                return null;
            }
            playoutSpec.append(",policyweights" + p + "=" + weightsPath);
        }
        if (boosted) {
            playoutSpec.append(",boosted=true");
        }

        final String agentSpec = StringRoutines.join(";", "algorithm=MCTS", "selection=noisyag0selection",
                playoutSpec.toString(), "final_move=robustchild", "tree_reuse=true",
                "learned_selection_policy=playout", "friendly_name=BiasedMCTS");
        return (SoftmaxPolicyLinear) ((MCTS) AIFactory.createAI(agentSpec)).playoutStrategy();
    }

    /**
     * Loads, for every player {@code p >= 1}, the latest {@code ExperienceBuffer_P<p>*.buf} file
     * from {@code str}, trying the prioritized format first and the uniform format second.
     * Entries that cannot be loaded are left {@code null} (errors are printed).
     */
    private static void loadExperienceBuffers(final Game game2, final String str, final ExperienceBuffer[] experienceBufferArr) {
        for (int p = 1; p < experienceBufferArr.length; ++p) {
            final String filepath = ExperimentFileUtils.getLastFilepath(str + "ExperienceBuffer_P" + p, "buf");
            if (filepath == null) {
                System.err.println("No experience buffer file found for Player " + p + " in " + str);
                experienceBufferArr[p] = null;
                continue;
            }

            ExperienceBuffer buffer = null;
            try {
                buffer = PrioritizedReplayBuffer.fromFile(game2, filepath);
            } catch (final Exception prioritizedError) {
                try {
                    buffer = UniformExperienceBuffer.fromFile(game2, filepath);
                } catch (final Exception uniformError) {
                    prioritizedError.printStackTrace();
                    uniformError.printStackTrace();
                }
            }
            experienceBufferArr[p] = buffer;
        }
    }

    /** Returns {@code dir} with a trailing {@code '/'} appended if it does not already end with one. */
    private static String withTrailingSlash(final String dir) {
        return dir.endsWith("/") ? dir : dir + "/";
    }

    /**
     * Command-line entry point.
     *
     * @param strArr {@code --training-out-dir}, {@code --out-dir} and {@code --game} are required;
     *               {@code --ruleset} is optional
     */
    public static void main(final String[] strArr) {
        final CommandLineArgParse argParse = new CommandLineArgParse(true, "Identify top features for a game.");
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--training-out-dir").help("Directory with training results (features, weights, experience buffers, ...).").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--out-dir").help("Directory in which to write our outputs.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game").help("Name of the game.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ruleset").help("Ruleset name.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).withDefault(""));

        if (!argParse.parseArguments(strArr)) {
            return;
        }
        System.out.println("Identifying top features for args: " + Arrays.toString(strArr));
        identifyTopFeatures(argParse);
    }
}
