package expert_iteration;

import features.FeatureSet;
import features.FeatureUtils;
import features.elements.FeatureElement;
import features.elements.RelativeFeatureElement;
import features.features.Feature;
import features.generation.AtomicFeatureGenerator;
import features.instances.FeatureInstance;
import features.patterns.Pattern;
import function_approx.BoostedLinearFunction;
import function_approx.LinearFunction;
import game.Game;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import language.compiler.Compiler;
import main.CommandLineArgParse;
import main.FileHandling;
import main.collections.FVector;
import main.collections.FastArrayList;
import main.grammar.Report;
import metadata.ai.features.Features;
import metadata.ai.heuristics.Heuristics;
import metadata.ai.misc.BestAgent;
import optimisers.Optimiser;
import optimisers.OptimiserFactory;
import policies.softmax.SoftmaxPolicy;
import search.mcts.MCTS;
import search.mcts.finalmoveselection.ActRegPolOpt;
import search.mcts.finalmoveselection.RobustChild;
import search.mcts.selection.AG0Selection;
import search.mcts.selection.SearchRegPolOpt;
import search.mcts.utils.RegPolOptMCTS;
import search.minimax.AlphaBetaSearch;
import util.AI;
import util.Context;
import util.GameLoader;
import util.Move;
import util.Trial;
import utils.AIFactory;
import utils.AIUtils;
import utils.ExperimentFileUtils;
import utils.ExponentialMovingAverage;
import utils.data_structures.experience_buffers.ExperienceBuffer;
import utils.data_structures.experience_buffers.PrioritizedReplayBuffer;
import utils.data_structures.experience_buffers.UniformExperienceBuffer;
import utils.experiments.InterruptableExperiment;

/* loaded from: input_file:expert_iteration/ExpertIteration.class */
public class ExpertIteration {
    // Format strings for checkpoint file names. The code that applies them is not
    // visible in this part of the file; presumably (base name, counter, extension) — confirm.
    private static final String gameCheckpointFormat = "%s_%05d.%s";
    private static final String weightUpdateCheckpointFormat = "%s_%08d.%s";

    // --- Game and expert setup ---
    // Name and options handed to GameLoader.loadGameFromName(...) in startTraining().
    protected String gameName;
    public List<String> gameOptions;
    // Selects how the expert is built in runExperiment(); the values compared against
    // there are "BEST_AGENT", "FROM_METADATA" and "Biased MCTS".
    protected String expertAI;
    // Directory from which BestAgent.txt, BestHeuristics.txt and BestFeatures.txt are
    // read when expertAI is "BEST_AGENT".
    protected String bestAgentsDataDir;
    // Upper bound of the game loop in runExperiment() (increased there by the resumed
    // game counter when checkpointType is Game).
    protected int numTrainingGames;
    // When >= 0, startTraining() caps the game's max turns at
    // min(gameLengthCap, game.getMaxTurnLimit()).
    public int gameLengthCap;
    // Search budget passed to ExpertPolicy.selectAction(...) for every move.
    protected double thinkingTime;
    protected int iterationLimit;
    protected int depthLimit;

    // --- Training schedule ---
    // Feature sets are expanded before game g when g > 0 and g % addFeatureEvery == 0
    // (unless noGrowFeatureSet is set).
    protected int addFeatureEvery;
    // Number of experiences requested from an experience buffer per sampled batch.
    protected int batchSize;
    // Usage not visible in this part of the file.
    protected int experienceBufferSize;
    // A weight update is run every updateWeightsEvery applied moves.
    protected int updateWeightsEvery;
    // Disables the feature-set expansion step.
    protected boolean noGrowFeatureSet;
    // Enables the TSPG objective; runExperiment() refuses non-MCTS experts when set.
    protected boolean trainTSPG;
    // Optimiser configuration strings; the code consuming them is not visible here.
    protected String crossEntropyOptimiserConfig;
    protected String ceExploreOptimiserConfig;
    protected String tspgOptimiserConfig;
    protected String valueOptimiserConfig;
    // Passed as the last argument to expandFeatureSetCorrelationBased(...).
    protected int combiningFeatureInstanceThreshold;
    // When set, each sample's gradient weight is scaled by
    // (moving average game duration / sample's episode duration).
    protected boolean importanceSamplingEpisodeDurations;
    // Passed to prepareExperienceBuffers(...); when set, per-sample PER weights and
    // priorities are also computed during weight updates.
    protected boolean prioritizedExperienceReplay;

    // --- CE-explore options ---
    // When set, moves are sampled from a mix of the expert distribution and the
    // CE-explore policy's distribution.
    protected boolean ceExplore;
    // Mixing weight of the CE-explore distribution (expert gets 1 - ceExploreMix).
    protected float ceExploreMix;
    // Usage not visible in this part of the file.
    protected double ceExploreGamma;
    // Usage not visible in this part of the file.
    protected boolean ceExploreUniform;
    // When set, the CE-explore importance-sampling weight is NOT applied to gradients.
    protected boolean noCEExploreIS;
    // Usage not visible in this part of the file.
    protected boolean weightedImportanceSampling;
    // Usage not visible in this part of the file.
    protected boolean noValueLearning;
    // When set, "Biased MCTS" experts use SearchRegPolOpt/ActRegPolOpt and the expert
    // distribution comes from RegPolOptMCTS.computePiBar(...).
    protected boolean mctsRegPolOpt;
    // Passed to the SoftmaxPolicy constructors for the cross-entropy and TSPG policies.
    protected int maxNumBiasedPlayoutActions;

    // --- Initial feature pruning (usage not visible in this part of the file) ---
    protected boolean noPruneInitFeatures;
    protected int pruneInitFeaturesThreshold;
    protected int numPruningGames;
    protected int maxNumPruningSeconds;

    // --- Output and runtime control ---
    // Output directory; created by runExperiment() if missing. null means no output
    // files are written (a warning is printed).
    protected File outDir;
    // Decides whether the resumed checkpoint counter is read as a game counter or a
    // weight-update counter in runExperiment().
    protected CheckpointTypes checkpointType;
    // Usage not visible in this part of the file.
    protected int checkpointFrequency;
    // Usage not visible in this part of the file.
    protected boolean noLogging;
    // Both are forwarded to the InterruptableExperiment constructor in startTraining().
    protected boolean useGUI;
    protected int maxWallTime;

    /* loaded from: input_file:expert_iteration/ExpertIteration$CheckpointTypes.class */
    /**
     * Unit in which the checkpoint counter of an experiment is expressed.
     */
    public enum CheckpointTypes {

        /** The checkpoint counter is interpreted as a number of games. */
        Game,

        /** The checkpoint counter is interpreted as a number of weight updates. */
        WeightUpdate;
    }

    /**
     * Creates an experiment with every option left at its Java default value
     * (this constructor assigns nothing).
     */
    public ExpertIteration() {
    }

    /**
     * Creates an experiment and records the GUI flag.
     *
     * @param useGUI stored in {@link #useGUI}, which is later forwarded to the
     *               InterruptableExperiment constructor
     */
    public ExpertIteration(final boolean useGUI) {
        this.useGUI = useGUI;
    }

    /**
     * Creates an experiment and records the GUI flag and the wall-time limit.
     *
     * @param useGUI      stored in {@link #useGUI}
     * @param maxWallTime stored in {@link #maxWallTime}; forwarded, together with
     *                    the GUI flag, to the InterruptableExperiment constructor
     */
    public ExpertIteration(final boolean useGUI, final int maxWallTime) {
        this(useGUI);
        this.maxWallTime = maxWallTime;
    }

    /**
     * Runs the experiment: obtains the log writer from {@code createLogWriter()},
     * hands it to {@code startTraining(...)}, and closes the writer afterwards
     * whether training completes normally or throws.
     *
     * <p>The writer may be {@code null}; try-with-resources skips closing a null
     * resource, so that case needs no explicit check. Any exception thrown by
     * training propagates to the caller unchanged.
     */
    public void startExperiment() {
        // Replaces the decompiler's hand-expanded resource handling, whose dead
        // branches dereferenced a variable that was always null.
        try (PrintWriter logWriter = createLogWriter()) {
            startTraining(logWriter);
        }
    }

    private void startTraining(final PrintWriter printWriter) {
        final Game loadGameFromName = GameLoader.loadGameFromName(this.gameName, this.gameOptions);
        final int count = loadGameFromName.players().count();
        if (this.gameLengthCap >= 0) {
            loadGameFromName.setMaxTurns(Math.min(this.gameLengthCap, loadGameFromName.getMaxTurnLimit()));
        }
        new InterruptableExperiment(this.useGUI, this.maxWallTime) { // from class: expert_iteration.ExpertIteration.1
            // Checkpoint counter. initMembers() resets it to Long.MAX_VALUE; runExperiment()
            // reads it as a weight-update counter or a game counter depending on
            // checkpointType. Presumably lowered by the prepare*() loaders when checkpoint
            // files are found — that code is not visible here, confirm.
            protected long lastCheckpoint;

            // File-name bookkeeping for the files of the current checkpoint. Each array is
            // allocated by initMembers() with (player count + 1) entries — presumably
            // indexed by player number with slot 0 unused, confirm against the save/load code.
            protected String[] currentFeatureSetFilenames;
            protected String[] currentPolicyWeightsCEFilenames;
            protected String[] currentPolicyWeightsTSPGFilenames;
            protected String[] currentPolicyWeightsCEEFilenames;
            // Single file name (not per player); reset to null by initMembers().
            protected String currentValueFunctionFilename;
            protected String[] currentExperienceBufferFilenames;
            protected String[] currentGameDurationTrackerFilenames;
            protected String[] currentOptimiserCEFilenames;
            protected String[] currentOptimiserTSPGFilenames;
            protected String[] currentOptimiserCEEFilenames;
            // Single file name (not per player); reset to null by initMembers().
            protected String currentOptimiserValueFilename;

            /* JADX INFO: Access modifiers changed from: package-private */
            /* renamed from: expert_iteration.ExpertIteration$1$CombinableFeatureInstancePair */
            /* loaded from: input_file:expert_iteration/ExpertIteration$1$CombinableFeatureInstancePair.class */
            /**
             * A pair of feature instances together with the feature obtained by
             * combining them. The two instances are always combined in a canonical
             * order, so the combined feature does not depend on which instance was
             * passed first. Equality and hashing are defined purely by the combined
             * feature.
             */
            public final class CombinableFeatureInstancePair {
                /** First instance, exactly as passed to the constructor. */
                public final FeatureInstance a;
                /** Second instance, exactly as passed to the constructor. */
                public final FeatureInstance b;
                /** Feature produced by combining the two instances in canonical order. */
                protected final Feature combinedFeature;
                /** Lazily cached hash code; {@code Integer.MIN_VALUE} means "not computed yet". */
                private int cachedHash = Integer.MIN_VALUE;

                /**
                 * @param game2            game passed through to {@code Feature.combineFeatures}
                 * @param featureInstance  first instance of the pair
                 * @param featureInstance2 second instance of the pair
                 */
                public CombinableFeatureInstancePair(Game game2, FeatureInstance featureInstance, FeatureInstance featureInstance2) {
                    this.a = featureInstance;
                    this.b = featureInstance2;
                    this.combinedFeature = keepsArgumentOrder(featureInstance, featureInstance2)
                            ? Feature.combineFeatures(game2, featureInstance, featureInstance2)
                            : Feature.combineFeatures(game2, featureInstance2, featureInstance);
                }

                /**
                 * Decides the canonical combination order. Tie-breakers, in order:
                 * lower feature-set index first, higher reflection first, lower
                 * rotation first, lower anchor site first; a full tie keeps the
                 * argument order.
                 *
                 * @return {@code true} to combine as (first, second), {@code false}
                 *         to combine as (second, first)
                 */
                private boolean keepsArgumentOrder(FeatureInstance first, FeatureInstance second) {
                    if (first.feature().featureSetIndex() < second.feature().featureSetIndex()) {
                        return true;
                    }
                    if (second.feature().featureSetIndex() < first.feature().featureSetIndex()) {
                        return false;
                    }
                    if (first.reflection() > second.reflection()) {
                        return true;
                    }
                    if (second.reflection() > first.reflection()) {
                        return false;
                    }
                    if (first.rotation() < second.rotation()) {
                        return true;
                    }
                    if (second.rotation() < first.rotation()) {
                        return false;
                    }
                    if (first.anchorSite() < second.anchorSite()) {
                        return true;
                    }
                    // Swapped only when the second anchor is strictly lower; a full tie keeps the order.
                    return !(second.anchorSite() < first.anchorSite());
                }

                /** Two pairs are equal when their combined features are equal. */
                @Override
                public boolean equals(Object obj) {
                    if (!(obj instanceof CombinableFeatureInstancePair)) {
                        return false;
                    }
                    final CombinableFeatureInstancePair other = (CombinableFeatureInstancePair) obj;
                    return this.combinedFeature.equals(other.combinedFeature);
                }

                /** Hash of the combined feature, computed on first use and then cached. */
                @Override
                public int hashCode() {
                    int hash = this.cachedHash;
                    if (hash == Integer.MIN_VALUE) {
                        hash = this.combinedFeature.hashCode();
                        this.cachedHash = hash;
                    }
                    return hash;
                }

                /** Combined feature followed by the two source instances. */
                @Override
                public String toString() {
                    final StringBuilder text = new StringBuilder();
                    text.append(this.combinedFeature);
                    text.append(" (from ").append(this.a);
                    text.append(" and ").append(this.b);
                    text.append(")");
                    return text.toString();
                }
            }

            /* JADX INFO: Access modifiers changed from: package-private */
            /* renamed from: expert_iteration.ExpertIteration$1$ScoredPair */
            /* loaded from: input_file:expert_iteration/ExpertIteration$1$ScoredPair.class */
            /**
             * Immutable holder that attaches a numeric score to a
             * {@link CombinableFeatureInstancePair}.
             */
            public final class ScoredPair {
                /** The scored pair of feature instances. */
                public final CombinableFeatureInstancePair pair;
                /** Score assigned to {@link #pair}. */
                public final double score;

                /**
                 * @param scoredPair pair of feature instances being scored
                 * @param pairScore  score for that pair
                 */
                public ScoredPair(final CombinableFeatureInstancePair scoredPair, final double pairScore) {
                    this.score = pairScore;
                    this.pair = scoredPair;
                }
            }

            /**
             * Resets the checkpoint bookkeeping of this experiment: the checkpoint
             * counter goes back to {@code Long.MAX_VALUE}, the two single file names
             * are cleared, and every per-player file-name array is re-allocated with
             * (player count + 1) empty entries.
             */
            private void initMembers() {
                final int numSlots = count + 1;

                // Scalar state.
                this.lastCheckpoint = Long.MAX_VALUE;
                this.currentValueFunctionFilename = null;
                this.currentOptimiserValueFilename = null;

                // Feature sets and policy weights.
                this.currentFeatureSetFilenames = new String[numSlots];
                this.currentPolicyWeightsCEFilenames = new String[numSlots];
                this.currentPolicyWeightsTSPGFilenames = new String[numSlots];
                this.currentPolicyWeightsCEEFilenames = new String[numSlots];

                // Experience buffers and game-duration trackers.
                this.currentExperienceBufferFilenames = new String[numSlots];
                this.currentGameDurationTrackerFilenames = new String[numSlots];

                // Optimisers.
                this.currentOptimiserCEFilenames = new String[numSlots];
                this.currentOptimiserTSPGFilenames = new String[numSlots];
                this.currentOptimiserCEEFilenames = new String[numSlots];
            }

            @Override // utils.experiments.InterruptableExperiment
            public void runExperiment() {
                Move move;
                FVector mean;
                AI alphaBetaSearch;
                if (ExpertIteration.this.outDir == null) {
                    System.err.println("Warning: we're not writing any output files for this run!");
                } else if (!ExpertIteration.this.outDir.exists()) {
                    ExpertIteration.this.outDir.mkdirs();
                }
                initMembers();
                FeatureSet[] prepareFeatureSets = prepareFeatureSets();
                LinearFunction[] prepareCrossEntropyFunctions = prepareCrossEntropyFunctions(prepareFeatureSets);
                LinearFunction[] prepareTSPGFunctions = prepareTSPGFunctions(prepareFeatureSets, prepareCrossEntropyFunctions);
                LinearFunction[] prepareCEExploreFunctions = prepareCEExploreFunctions(prepareFeatureSets);
                SoftmaxPolicy softmaxPolicy = new SoftmaxPolicy(prepareCrossEntropyFunctions, prepareFeatureSets, ExpertIteration.this.maxNumBiasedPlayoutActions);
                SoftmaxPolicy softmaxPolicy2 = new SoftmaxPolicy(prepareTSPGFunctions, prepareFeatureSets, ExpertIteration.this.maxNumBiasedPlayoutActions);
                SoftmaxPolicy softmaxPolicy3 = new SoftmaxPolicy(prepareCEExploreFunctions, prepareFeatureSets);
                Heuristics prepareValueFunction = prepareValueFunction();
                Optimiser[] prepareCrossEntropyOptimisers = prepareCrossEntropyOptimisers();
                Optimiser[] prepareTSPGOptimisers = prepareTSPGOptimisers();
                Optimiser[] prepareCEExploreOptimisers = prepareCEExploreOptimisers();
                Optimiser prepareValueFunctionOptimiser = prepareValueFunctionOptimiser();
                Context context = new Context(loadGameFromName, new Trial(loadGameFromName));
                ArrayList arrayList = new ArrayList(count + 1);
                arrayList.add(null);
                for (int i = 1; i <= count; i++) {
                    Report report = new Report();
                    if (ExpertIteration.this.expertAI.equals("BEST_AGENT")) {
                        try {
                            BestAgent bestAgent = (BestAgent) Compiler.compileObject(FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestAgent.txt"), "metadata.ai.misc.BestAgent", report);
                            if (bestAgent.agent().equals("AlphaBeta") || bestAgent.agent().equals("Alpha-Beta")) {
                                alphaBetaSearch = new AlphaBetaSearch(ExpertIteration.this.bestAgentsDataDir + "/BestHeuristics.txt");
                            } else if (bestAgent.agent().equals("AlphaBetaMetadata")) {
                                alphaBetaSearch = new AlphaBetaSearch();
                            } else if (bestAgent.agent().equals("UCT")) {
                                alphaBetaSearch = AIFactory.createAI("UCT");
                            } else if (bestAgent.agent().equals("MC-GRAVE")) {
                                alphaBetaSearch = AIFactory.createAI("MC-GRAVE");
                            } else if (bestAgent.agent().equals("Biased MCTS")) {
                                alphaBetaSearch = MCTS.createBiasedMCTS((Features) Compiler.compileObject(FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestFeatures.txt"), "metadata.ai.features.Features", report), true);
                            } else if (bestAgent.agent().equals("Biased MCTS (Uniform Playouts)")) {
                                alphaBetaSearch = MCTS.createBiasedMCTS((Features) Compiler.compileObject(FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestFeatures.txt"), "metadata.ai.features.Features", report), false);
                            } else if (bestAgent.agent().equals("Biased MCTS (RegPolOpt)")) {
                                alphaBetaSearch = MCTS.createRegPolOptMCTS((Features) Compiler.compileObject(FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestFeatures.txt"), "metadata.ai.features.Features", report), true);
                            } else if (bestAgent.agent().equals("Biased MCTS (RegPolOpt, Uniform Playouts)")) {
                                alphaBetaSearch = MCTS.createRegPolOptMCTS((Features) Compiler.compileObject(FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestFeatures.txt"), "metadata.ai.features.Features", report), false);
                            } else {
                                if (!bestAgent.agent().equals("Random")) {
                                    System.err.println("Unrecognised best agent: " + bestAgent.agent());
                                    return;
                                }
                                alphaBetaSearch = MCTS.createUCT();
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                            return;
                        }
                    } else if (ExpertIteration.this.expertAI.equals("FROM_METADATA")) {
                        alphaBetaSearch = AIFactory.fromMetadata(loadGameFromName);
                        if (alphaBetaSearch == null) {
                            System.err.println("AI from metadata is null!");
                            return;
                        } else if (!(alphaBetaSearch instanceof ExpertPolicy)) {
                            System.err.println("AI from metadata is not an expert policy!");
                            return;
                        }
                    } else {
                        if (!ExpertIteration.this.expertAI.equals("Biased MCTS")) {
                            System.err.println("Cannot recognise expert AI: " + ExpertIteration.this.expertAI);
                            return;
                        }
                        MCTS mcts = ExpertIteration.this.mctsRegPolOpt ? new MCTS(new SearchRegPolOpt(), softmaxPolicy, new ActRegPolOpt()) : new MCTS(new AG0Selection(), softmaxPolicy, new RobustChild());
                        mcts.setLearnedSelectionPolicy(softmaxPolicy);
                        mcts.friendlyName = "Biased MCTS";
                        alphaBetaSearch = mcts;
                    }
                    if (alphaBetaSearch instanceof MCTS) {
                        ((MCTS) alphaBetaSearch).setPreserveRootNode(true);
                    } else if (ExpertIteration.this.trainTSPG) {
                        System.err.println("A non-MCTS expert cannot be used for training the TSPG objective!");
                        return;
                    }
                    arrayList.add((ExpertPolicy) alphaBetaSearch);
                }
                ExperienceBuffer[] prepareExperienceBuffers = prepareExperienceBuffers(ExpertIteration.this.prioritizedExperienceReplay);
                ExponentialMovingAverage[] prepareGameDurationTrackers = prepareGameDurationTrackers();
                long j = 0;
                long j2 = ExpertIteration.this.checkpointType == CheckpointTypes.WeightUpdate ? this.lastCheckpoint : 0L;
                int i2 = 0;
                if (ExpertIteration.this.checkpointType == CheckpointTypes.Game && this.lastCheckpoint >= 0) {
                    i2 = (int) this.lastCheckpoint;
                    ExpertIteration.this.numTrainingGames = (int) (r0.numTrainingGames + this.lastCheckpoint);
                }
                while (true) {
                    if (i2 >= ExpertIteration.this.numTrainingGames) {
                        break;
                    }
                    checkWallTime(0.05d);
                    if (this.interrupted) {
                        logLine(printWriter, "interrupting experiment...");
                        break;
                    }
                    saveCheckpoints(i2, j2, prepareFeatureSets, prepareCrossEntropyFunctions, prepareTSPGFunctions, prepareCEExploreFunctions, prepareValueFunction, prepareExperienceBuffers, prepareCrossEntropyOptimisers, prepareTSPGOptimisers, prepareCEExploreOptimisers, prepareValueFunctionOptimiser, prepareGameDurationTrackers, false);
                    FeatureSet[] featureSetArr = new FeatureSet[count + 1];
                    if (!ExpertIteration.this.noGrowFeatureSet && i2 > 0 && i2 % ExpertIteration.this.addFeatureEvery == 0) {
                        for (int i3 = 1; i3 <= count; i3++) {
                            ExItExperience[] sampleExperienceBatchUniformly = prepareExperienceBuffers[i3].sampleExperienceBatchUniformly(ExpertIteration.this.batchSize);
                            if (sampleExperienceBatchUniformly.length > 0) {
                                long currentTimeMillis = System.currentTimeMillis();
                                FeatureSet expandFeatureSetCorrelationBased = expandFeatureSetCorrelationBased(sampleExperienceBatchUniformly, prepareFeatureSets[i3], softmaxPolicy, loadGameFromName, ExpertIteration.this.combiningFeatureInstanceThreshold);
                                if (expandFeatureSetCorrelationBased != null) {
                                    featureSetArr[i3] = expandFeatureSetCorrelationBased;
                                    expandFeatureSetCorrelationBased.instantiateFeatures(loadGameFromName, new int[]{i3}, null);
                                } else {
                                    featureSetArr[i3] = prepareFeatureSets[i3];
                                }
                                logLine(printWriter, "Expanded feature set in " + (System.currentTimeMillis() - currentTimeMillis) + " ms for P" + i3 + ".");
                            } else {
                                featureSetArr[i3] = prepareFeatureSets[i3];
                            }
                        }
                        softmaxPolicy.updateFeatureSets(featureSetArr);
                        if (ExpertIteration.this.trainTSPG) {
                            softmaxPolicy2.updateFeatureSets(featureSetArr);
                        }
                        if (ExpertIteration.this.ceExplore) {
                            softmaxPolicy3.updateFeatureSets(featureSetArr);
                        }
                        prepareFeatureSets = featureSetArr;
                    }
                    logLine(printWriter, "starting game " + (i2 + 1));
                    loadGameFromName.start(context);
                    ArrayList arrayList2 = new ArrayList(count + 1);
                    arrayList2.add(null);
                    for (int i4 = 1; i4 < arrayList.size(); i4++) {
                        ((ExpertPolicy) arrayList.get(i4)).initAI(loadGameFromName, i4);
                        arrayList2.add(new ArrayList());
                    }
                    double d = 1.0d;
                    ArrayList arrayList3 = new ArrayList();
                    TIntArrayList tIntArrayList = new TIntArrayList();
                    TFloatArrayList tFloatArrayList = new TFloatArrayList();
                    while (true) {
                        if (context.trial().over()) {
                            break;
                        }
                        if (this.interrupted) {
                            logLine(printWriter, "interrupting experiment...");
                            break;
                        }
                        int mover = context.state().mover();
                        ExpertPolicy expertPolicy = (ExpertPolicy) arrayList.get(context.state().playerToAgent(mover));
                        expertPolicy.selectAction(loadGameFromName, new Context(context), ExpertIteration.this.thinkingTime, ExpertIteration.this.iterationLimit, ExpertIteration.this.depthLimit);
                        FastArrayList<Move> fastArrayList = new FastArrayList<>();
                        Iterator<Move> it = expertPolicy.lastSearchRootMoves().iterator();
                        while (it.hasNext()) {
                            fastArrayList.add(it.next());
                        }
                        FVector computePiBar = ExpertIteration.this.mctsRegPolOpt ? RegPolOptMCTS.computePiBar(((MCTS) expertPolicy).rootNode(), 2.5d) : expertPolicy.computeExpertPolicy(1.0d);
                        if (ExpertIteration.this.ceExplore) {
                            FVector computeDistribution = softmaxPolicy3.computeDistribution(context, fastArrayList, false);
                            FVector copy = computePiBar.copy();
                            copy.mult(1.0f - ExpertIteration.this.ceExploreMix);
                            copy.addScaled(computeDistribution, ExpertIteration.this.ceExploreMix);
                            int sampleProportionally = copy.sampleProportionally();
                            move = fastArrayList.get(sampleProportionally);
                            d *= computePiBar.get(sampleProportionally) / copy.get(sampleProportionally);
                            List<TIntArrayList> computeSparseFeatureVectors = prepareFeatureSets[mover].computeSparseFeatureVectors(context, fastArrayList, false);
                            FVector fVector = new FVector(softmaxPolicy3.linearFunction(mover).trainableParams().dim());
                            for (int i5 = 0; i5 < computeSparseFeatureVectors.size(); i5++) {
                                TIntArrayList tIntArrayList2 = computeSparseFeatureVectors.get(i5);
                                for (int i6 = 0; i6 < tIntArrayList2.size(); i6++) {
                                    fVector.addToEntry(tIntArrayList2.getQuick(i6), (-1.0f) * computeDistribution.get(i5));
                                }
                            }
                            TIntArrayList tIntArrayList3 = computeSparseFeatureVectors.get(sampleProportionally);
                            for (int i7 = 0; i7 < tIntArrayList3.size(); i7++) {
                                fVector.addToEntry(tIntArrayList3.getQuick(i7), 1.0f);
                            }
                            arrayList3.add(fVector);
                            FVector computeDistribution2 = softmaxPolicy.computeDistribution(computeSparseFeatureVectors, mover);
                            FVector copy2 = computePiBar.copy();
                            copy2.subtract(computeDistribution2);
                            copy2.abs();
                            tFloatArrayList.add(copy2.sum());
                            tIntArrayList.add(mover);
                        } else {
                            move = fastArrayList.get(computePiBar.sampleProportionally());
                        }
                        ExItExperience generateExItExperience = expertPolicy.generateExItExperience();
                        if (prepareValueFunction != null) {
                            generateExItExperience.setStateFeatureVector(prepareValueFunction.computeStateFeatureVector(context, mover));
                        }
                        ((List) arrayList2.get(mover)).add(generateExItExperience);
                        if (ExpertIteration.this.ceExplore) {
                            generateExItExperience.setWeightCEExplore((float) d);
                        }
                        loadGameFromName.apply(context, move);
                        j++;
                        if (j % ExpertIteration.this.updateWeightsEvery == 0) {
                            for (int i8 = 1; i8 <= count; i8++) {
                                ExItExperience[] sampleExperienceBatch = prepareExperienceBuffers[i8].sampleExperienceBatch(ExpertIteration.this.batchSize);
                                if (sampleExperienceBatch.length != 0) {
                                    ArrayList arrayList4 = new ArrayList(sampleExperienceBatch.length);
                                    ArrayList arrayList5 = new ArrayList(sampleExperienceBatch.length);
                                    new ArrayList(sampleExperienceBatch.length);
                                    ArrayList arrayList6 = new ArrayList(sampleExperienceBatch.length);
                                    int[] iArr = new int[sampleExperienceBatch.length];
                                    float[] fArr = new float[sampleExperienceBatch.length];
                                    double d2 = 0.0d;
                                    for (int i9 = 0; i9 < sampleExperienceBatch.length; i9++) {
                                        ExItExperience exItExperience = sampleExperienceBatch[i9];
                                        List<TIntArrayList> computeSparseFeatureVectors2 = prepareFeatureSets[i8].computeSparseFeatureVectors(exItExperience.state().state(), exItExperience.state().lastDecisionMove(), exItExperience.moves(), false);
                                        FVector computeDistributionErrors = softmaxPolicy.computeDistributionErrors(softmaxPolicy.computeDistribution(computeSparseFeatureVectors2, exItExperience.state().state().mover()), exItExperience.expertDistribution());
                                        if (exItExperience.state().state().mover() != i8) {
                                            System.err.println("Sample's mover not equal to p!");
                                        }
                                        FVector computeParamGradients = softmaxPolicy.computeParamGradients(computeDistributionErrors, computeSparseFeatureVectors2, exItExperience.state().state().mover());
                                        FVector fVector2 = null;
                                        if (prepareValueFunction != null) {
                                            FVector paramsVector = prepareValueFunction.paramsVector();
                                            float tanh = (float) Math.tanh(paramsVector.dot(exItExperience.stateFeatureVector()));
                                            float f = tanh - ((float) exItExperience.playerOutcomes()[exItExperience.state().state().mover()]);
                                            fVector2 = new FVector(paramsVector.dim());
                                            float f2 = 2.0f * f * (1.0f - (tanh * tanh));
                                            for (int i10 = 0; i10 < fVector2.dim(); i10++) {
                                                fVector2.set(i10, f2 * exItExperience.stateFeatureVector().get(i10));
                                            }
                                        }
                                        double movingAvg = ExpertIteration.this.importanceSamplingEpisodeDurations ? 1.0d * (prepareGameDurationTrackers[i8].movingAvg() / exItExperience.episodeDuration()) : 1.0d;
                                        if (ExpertIteration.this.prioritizedExperienceReplay) {
                                            FVector copy3 = computeDistributionErrors.copy();
                                            copy3.abs();
                                            fArr[i9] = Math.max(0.05f, copy3.sum());
                                            movingAvg *= exItExperience.weightPER();
                                            iArr[i9] = exItExperience.bufferIdx();
                                        }
                                        if (ExpertIteration.this.ceExplore && !ExpertIteration.this.noCEExploreIS) {
                                            float weightCEExplore = exItExperience.weightCEExplore();
                                            if (weightCEExplore < 0.1f) {
                                                weightCEExplore = 0.1f;
                                            } else if (weightCEExplore > 2.0f) {
                                                weightCEExplore = 2.0f;
                                            }
                                            movingAvg *= weightCEExplore;
                                        }
                                        d2 += movingAvg;
                                        computeParamGradients.mult((float) movingAvg);
                                        arrayList4.add(computeParamGradients);
                                        if (fVector2 != null) {
                                            fVector2.mult((float) movingAvg);
                                            arrayList6.add(fVector2);
                                        }
                                        if (ExpertIteration.this.trainTSPG) {
                                            FVector computeDistribution3 = softmaxPolicy2.computeDistribution(computeSparseFeatureVectors2, exItExperience.state().state().mover());
                                            FVector expertValueEstimates = exItExperience.expertValueEstimates();
                                            FVector fVector3 = new FVector(prepareTSPGFunctions[i8].trainableParams().dim());
                                            for (int i11 = 0; i11 < exItExperience.moves().size(); i11++) {
                                                float f3 = expertValueEstimates.get(i11);
                                                float f4 = computeDistribution3.get(i11);
                                                for (int i12 = 0; i12 < exItExperience.moves().size(); i12++) {
                                                    TIntArrayList tIntArrayList4 = computeSparseFeatureVectors2.get(i12);
                                                    for (int i13 = 0; i13 < tIntArrayList4.size(); i13++) {
                                                        int quick = tIntArrayList4.getQuick(i13);
                                                        if (i11 == i12) {
                                                            fVector3.addToEntry(quick, f3 * f4 * (1.0f - f4));
                                                        } else {
                                                            fVector3.addToEntry(quick, f3 * f4 * (0.0f - computeDistribution3.get(i12)));
                                                        }
                                                    }
                                                }
                                            }
                                            arrayList5.add(fVector3);
                                        }
                                    }
                                    FVector fVector4 = null;
                                    if (ExpertIteration.this.weightedImportanceSampling) {
                                        mean = ((FVector) arrayList4.get(0)).copy();
                                        for (int i14 = 1; i14 < arrayList4.size(); i14++) {
                                            mean.add((FVector) arrayList4.get(i14));
                                        }
                                        mean.div((float) d2);
                                        if (!arrayList6.isEmpty()) {
                                            fVector4 = ((FVector) arrayList6.get(0)).copy();
                                            for (int i15 = 1; i15 < arrayList6.size(); i15++) {
                                                fVector4.add((FVector) arrayList6.get(i15));
                                            }
                                            fVector4.div((float) d2);
                                        }
                                    } else {
                                        mean = FVector.mean(arrayList4);
                                        if (!arrayList6.isEmpty()) {
                                            fVector4 = FVector.mean(arrayList6);
                                        }
                                    }
                                    prepareCrossEntropyOptimisers[i8].minimiseObjective(prepareCrossEntropyFunctions[i8].trainableParams(), mean);
                                    if (fVector4 != null) {
                                        FVector paramsVector2 = prepareValueFunction.paramsVector();
                                        prepareValueFunctionOptimiser.minimiseObjective(paramsVector2, fVector4);
                                        prepareValueFunction.updateParams(loadGameFromName, paramsVector2, 0);
                                    }
                                    if (ExpertIteration.this.trainTSPG) {
                                        prepareTSPGOptimisers[i8].maximiseObjective(prepareTSPGFunctions[i8].trainableParams(), FVector.mean(arrayList5));
                                    }
                                    if (ExpertIteration.this.prioritizedExperienceReplay) {
                                        ((PrioritizedReplayBuffer) prepareExperienceBuffers[i8]).setPriorities(iArr, fArr);
                                    }
                                }
                            }
                            j2++;
                        }
                    }
                    if (!this.interrupted) {
                        for (int i16 = 1; i16 <= count; i16++) {
                            Collections.shuffle((List) arrayList2.get(i16), ThreadLocalRandom.current());
                            int size = ((List) arrayList2.get(i16)).size();
                            prepareGameDurationTrackers[i16].observe(size);
                            double[] agentUtilities = AIUtils.agentUtilities(context);
                            for (ExItExperience exItExperience2 : (List) arrayList2.get(i16)) {
                                exItExperience2.setEpisodeDuration(size);
                                exItExperience2.setPlayerOutcomes(agentUtilities);
                                prepareExperienceBuffers[i16].add(exItExperience2);
                            }
                        }
                        if (ExpertIteration.this.ceExplore && !ExpertIteration.this.ceExploreUniform) {
                            ArrayList arrayList7 = new ArrayList(count + 1);
                            arrayList7.add(null);
                            for (int i17 = 1; i17 <= count; i17++) {
                                arrayList7.add(new ArrayList());
                            }
                            for (int i18 = 0; i18 < arrayList3.size(); i18++) {
                                FVector fVector5 = (FVector) arrayList3.get(i18);
                                float f5 = 0.0f;
                                for (int i19 = i18 + 1; i19 < tFloatArrayList.size(); i19++) {
                                    f5 = (float) (f5 + (Math.pow(ExpertIteration.this.ceExploreGamma, i19 - (i18 + 1)) * tFloatArrayList.getQuick(i19)));
                                }
                                fVector5.mult(f5);
                                ((List) arrayList7.get(tIntArrayList.getQuick(i18))).add(fVector5);
                            }
                            for (int i20 = 1; i20 <= count; i20++) {
                                if (((List) arrayList7.get(i20)).size() > 0) {
                                    prepareCEExploreOptimisers[i20].minimiseObjective(prepareCEExploreFunctions[i20].trainableParams(), FVector.mean((List<FVector>) arrayList7.get(i20)));
                                }
                            }
                        }
                    }
                    if (context.trial().over()) {
                        logLine(printWriter, "Finished running game " + (i2 + 1));
                    }
                    i2++;
                }
                saveCheckpoints(i2 + 1, j2, prepareFeatureSets, prepareCrossEntropyFunctions, prepareTSPGFunctions, prepareCEExploreFunctions, prepareValueFunction, prepareExperienceBuffers, prepareCrossEntropyOptimisers, prepareTSPGOptimisers, prepareCEExploreOptimisers, prepareValueFunctionOptimiser, prepareGameDurationTrackers, true);
            }

            /**
             * Builds the Cross-Entropy optimisers, one per index in {@code 1..count}
             * (index 0 of the returned array is left {@code null}).
             *
             * For every index the most recent "OptimiserCE_P&lt;i&gt;" checkpoint file is
             * looked up; if none exists a fresh optimiser is created from the configured
             * Cross-Entropy optimiser config, otherwise the optimiser is deserialised
             * from that checkpoint. {@code lastCheckpoint} is lowered to the checkpoint
             * number extracted from each filename.
             *
             * @return array of size {@code count + 1}; an entry stays {@code null} if
             *         loading its checkpoint failed (the exception is only printed).
             */
            private Optimiser[] prepareCrossEntropyOptimisers() {
                final Optimiser[] optimisers = new Optimiser[count + 1];
                for (int p = 1; p <= count; p++) {
                    Optimiser optimiser = null;
                    this.currentOptimiserCEFilenames[p] = getFilenameLastCheckpoint("OptimiserCE_P" + p, "opt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentOptimiserCEFilenames[p], "OptimiserCE_P" + p, "opt"));
                    if (this.currentOptimiserCEFilenames[p] == null) {
                        optimiser = OptimiserFactory.createOptimiser(ExpertIteration.this.crossEntropyOptimiserConfig);
                        logLine(printWriter, "starting with new optimiser for Cross-Entropy");
                    } else {
                        final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentOptimiserCEFilenames[p];
                        // try-with-resources guarantees the stream is closed even when readObject() throws.
                        try (ObjectInputStream reader = new ObjectInputStream(new BufferedInputStream(new FileInputStream(checkpointPath)))) {
                            optimiser = (Optimiser) reader.readObject();
                        } catch (IOException | ClassNotFoundException e) {
                            // Best effort: report the failure and carry on with a null entry.
                            e.printStackTrace();
                        }
                        logLine(printWriter, "continuing with CE optimiser loaded from " + this.currentOptimiserCEFilenames[p]);
                    }
                    optimisers[p] = optimiser;
                }
                return optimisers;
            }

            /**
             * Builds the TSPG optimisers, one per index in {@code 1..count}.
             *
             * When {@code trainTSPG} is disabled nothing is loaded or created and the
             * returned array contains only {@code null} entries. Otherwise, for every
             * index the most recent "OptimiserTSPG_P&lt;i&gt;" checkpoint file is looked
             * up; if none exists a fresh optimiser is created from the configured TSPG
             * optimiser config, otherwise the optimiser is deserialised from that
             * checkpoint. {@code lastCheckpoint} is lowered to the checkpoint number
             * extracted from each filename.
             *
             * @return array of size {@code count + 1} (index 0 unused); an entry stays
             *         {@code null} if loading its checkpoint failed.
             */
            private Optimiser[] prepareTSPGOptimisers() {
                final Optimiser[] optimisers = new Optimiser[count + 1];
                if (!ExpertIteration.this.trainTSPG) {
                    return optimisers;
                }
                for (int p = 1; p <= count; p++) {
                    Optimiser optimiser = null;
                    this.currentOptimiserTSPGFilenames[p] = getFilenameLastCheckpoint("OptimiserTSPG_P" + p, "opt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentOptimiserTSPGFilenames[p], "OptimiserTSPG_P" + p, "opt"));
                    if (this.currentOptimiserTSPGFilenames[p] == null) {
                        optimiser = OptimiserFactory.createOptimiser(ExpertIteration.this.tspgOptimiserConfig);
                        logLine(printWriter, "starting with new optimiser for TSPG");
                    } else {
                        final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentOptimiserTSPGFilenames[p];
                        // try-with-resources guarantees the stream is closed even when readObject() throws.
                        try (ObjectInputStream reader = new ObjectInputStream(new BufferedInputStream(new FileInputStream(checkpointPath)))) {
                            optimiser = (Optimiser) reader.readObject();
                        } catch (IOException | ClassNotFoundException e) {
                            // Best effort: report the failure and carry on with a null entry.
                            e.printStackTrace();
                        }
                        logLine(printWriter, "continuing with TSPG optimiser loaded from " + this.currentOptimiserTSPGFilenames[p]);
                    }
                    optimisers[p] = optimiser;
                }
                return optimisers;
            }

            /**
             * Builds the Cross-Entropy Exploration (CEE) optimisers, one per index in
             * {@code 1..count} (index 0 of the returned array is left {@code null}).
             *
             * For every index the most recent "OptimiserCEE_P&lt;i&gt;" checkpoint file is
             * looked up; if none exists a fresh optimiser is created from the configured
             * CE-explore optimiser config, otherwise the optimiser is deserialised from
             * that checkpoint. {@code lastCheckpoint} is lowered to the checkpoint number
             * extracted from each filename.
             *
             * @return array of size {@code count + 1}; an entry stays {@code null} if
             *         loading its checkpoint failed (the exception is only printed).
             */
            private Optimiser[] prepareCEExploreOptimisers() {
                final Optimiser[] optimisers = new Optimiser[count + 1];
                for (int p = 1; p <= count; p++) {
                    Optimiser optimiser = null;
                    this.currentOptimiserCEEFilenames[p] = getFilenameLastCheckpoint("OptimiserCEE_P" + p, "opt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentOptimiserCEEFilenames[p], "OptimiserCEE_P" + p, "opt"));
                    if (this.currentOptimiserCEEFilenames[p] == null) {
                        optimiser = OptimiserFactory.createOptimiser(ExpertIteration.this.ceExploreOptimiserConfig);
                        logLine(printWriter, "starting with new optimiser for CEE");
                    } else {
                        final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentOptimiserCEEFilenames[p];
                        // try-with-resources guarantees the stream is closed even when readObject() throws.
                        try (ObjectInputStream reader = new ObjectInputStream(new BufferedInputStream(new FileInputStream(checkpointPath)))) {
                            optimiser = (Optimiser) reader.readObject();
                        } catch (IOException | ClassNotFoundException e) {
                            // Best effort: report the failure and carry on with a null entry.
                            e.printStackTrace();
                        }
                        logLine(printWriter, "continuing with CEE optimiser loaded from " + this.currentOptimiserCEEFilenames[p]);
                    }
                    optimisers[p] = optimiser;
                }
                return optimisers;
            }

            /**
             * Builds the single optimiser used for the value function.
             *
             * The most recent "OptimiserValue" checkpoint file is looked up; if none
             * exists a fresh optimiser is created from the configured value optimiser
             * config, otherwise the optimiser is deserialised from that checkpoint.
             * {@code lastCheckpoint} is lowered to the checkpoint number extracted from
             * the filename.
             *
             * @return the optimiser, or {@code null} if a checkpoint file was found but
             *         could not be read (the exception is only printed).
             */
            private Optimiser prepareValueFunctionOptimiser() {
                Optimiser optimiser = null;
                this.currentOptimiserValueFilename = getFilenameLastCheckpoint("OptimiserValue", "opt");
                this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentOptimiserValueFilename, "OptimiserValue", "opt"));
                if (this.currentOptimiserValueFilename == null) {
                    optimiser = OptimiserFactory.createOptimiser(ExpertIteration.this.valueOptimiserConfig);
                    logLine(printWriter, "starting with new optimiser for Value function");
                } else {
                    final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentOptimiserValueFilename;
                    // try-with-resources guarantees the stream is closed even when readObject() throws.
                    try (ObjectInputStream reader = new ObjectInputStream(new BufferedInputStream(new FileInputStream(checkpointPath)))) {
                        optimiser = (Optimiser) reader.readObject();
                    } catch (IOException | ClassNotFoundException e) {
                        // Best effort: report the failure and carry on with a null optimiser.
                        e.printStackTrace();
                    }
                    logLine(printWriter, "continuing with Value function optimiser loaded from " + this.currentOptimiserValueFilename);
                }
                return optimiser;
            }

            /**
             * Builds the experience buffers, one per index in {@code 1..count}
             * (index 0 of the returned array is left {@code null}).
             *
             * For every index the most recent "ExperienceBuffer_P&lt;i&gt;" checkpoint file
             * is looked up. Without a checkpoint an empty buffer of the configured size
             * is created; with one, the buffer is restored from that file.
             * {@code lastCheckpoint} is lowered to the checkpoint number extracted from
             * each filename.
             *
             * @param z if {@code true} prioritized replay buffers are used, otherwise
             *          uniform experience buffers
             * @return array of size {@code count + 1} holding the buffers
             */
            private ExperienceBuffer[] prepareExperienceBuffers(boolean z) {
                final ExperienceBuffer[] buffers = new ExperienceBuffer[count + 1];
                for (int p = 1; p <= count; p++) {
                    final String prefix = "ExperienceBuffer_P" + p;
                    this.currentExperienceBufferFilenames[p] = getFilenameLastCheckpoint(prefix, "buf");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentExperienceBufferFilenames[p], prefix, "buf"));

                    final ExperienceBuffer buffer;
                    if (this.currentExperienceBufferFilenames[p] != null) {
                        // A checkpoint exists: restore the buffer of the requested kind from it.
                        final String bufferPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentExperienceBufferFilenames[p];
                        if (z) {
                            buffer = PrioritizedReplayBuffer.fromFile(loadGameFromName, bufferPath);
                        } else {
                            buffer = UniformExperienceBuffer.fromFile(loadGameFromName, bufferPath);
                        }
                        logLine(printWriter, "continuing with experience buffer loaded from " + this.currentExperienceBufferFilenames[p]);
                    } else {
                        // No checkpoint: start from an empty buffer of the requested kind.
                        if (z) {
                            buffer = new PrioritizedReplayBuffer(ExpertIteration.this.experienceBufferSize);
                        } else {
                            buffer = new UniformExperienceBuffer(ExpertIteration.this.experienceBufferSize);
                        }
                        logLine(printWriter, "starting with empty experience buffer");
                    }
                    buffers[p] = buffer;
                }
                return buffers;
            }

            /**
             * Builds the game-duration trackers, one per index in {@code 1..count}
             * (index 0 of the returned array is left {@code null}).
             *
             * For every index the most recent "GameDurationTracker_P&lt;i&gt;" checkpoint
             * file is looked up; if none exists a new {@link ExponentialMovingAverage}
             * is created, otherwise the tracker is deserialised from that checkpoint.
             * {@code lastCheckpoint} is lowered to the checkpoint number extracted from
             * each filename.
             *
             * @return array of size {@code count + 1}; an entry stays {@code null} if
             *         loading its checkpoint failed (the exception is only printed).
             */
            private ExponentialMovingAverage[] prepareGameDurationTrackers() {
                final ExponentialMovingAverage[] trackers = new ExponentialMovingAverage[count + 1];
                for (int p = 1; p <= count; p++) {
                    ExponentialMovingAverage tracker = null;
                    this.currentGameDurationTrackerFilenames[p] = getFilenameLastCheckpoint("GameDurationTracker_P" + p, "bin");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentGameDurationTrackerFilenames[p], "GameDurationTracker_P" + p, "bin"));
                    if (this.currentGameDurationTrackerFilenames[p] == null) {
                        tracker = new ExponentialMovingAverage();
                        logLine(printWriter, "starting with new tracker for average game duration");
                    } else {
                        final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentGameDurationTrackerFilenames[p];
                        // try-with-resources guarantees the stream is closed even when readObject() throws.
                        try (ObjectInputStream reader = new ObjectInputStream(new BufferedInputStream(new FileInputStream(checkpointPath)))) {
                            tracker = (ExponentialMovingAverage) reader.readObject();
                        } catch (IOException | ClassNotFoundException e) {
                            // Best effort: report the failure and carry on with a null entry.
                            e.printStackTrace();
                        }
                        // Logged after the load attempt, same as the sibling checkpoint loaders.
                        logLine(printWriter, "continuing with average game duration tracker loaded from " + this.currentGameDurationTrackerFilenames[p]);
                    }
                    trackers[p] = tracker;
                }
                return trackers;
            }

            /**
             * Builds the Cross-Entropy linear functions, one per index in
             * {@code 1..count} (index 0 of the returned array is left {@code null}).
             *
             * Without a "PolicyWeightsCE_P&lt;i&gt;" checkpoint a function is created over
             * a new weight vector sized to the feature set; with one, the weights are
             * loaded from file and a warning is printed to stderr if the feature set
             * file they refer to differs from the feature set file currently in use.
             * {@code lastCheckpoint} is lowered to the checkpoint number extracted from
             * each filename.
             *
             * @param featureSetArr feature sets indexed like the returned array; used
             *                      to size new weight vectors
             * @return array of size {@code count + 1} holding the functions
             */
            private LinearFunction[] prepareCrossEntropyFunctions(FeatureSet[] featureSetArr) {
                final LinearFunction[] functions = new LinearFunction[count + 1];
                final String outDirPrefix = ExpertIteration.this.outDir.getAbsolutePath() + File.separator;
                for (int p = 1; p <= count; p++) {
                    final String prefix = "PolicyWeightsCE_P" + p;
                    this.currentPolicyWeightsCEFilenames[p] = getFilenameLastCheckpoint(prefix, "txt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentPolicyWeightsCEFilenames[p], prefix, "txt"));

                    final LinearFunction function;
                    if (this.currentPolicyWeightsCEFilenames[p] != null) {
                        final String weightsPath = outDirPrefix + this.currentPolicyWeightsCEFilenames[p];
                        function = LinearFunction.fromFile(weightsPath);
                        logLine(printWriter, "continuing with Selection policy weights loaded from " + this.currentPolicyWeightsCEFilenames[p]);
                        try {
                            // Compare the feature set the weights were saved for against the one in use.
                            final String savedFeatureSetPath = new File(weightsPath).getParent() + File.separator + function.featureSetFile();
                            final String savedCanonical = new File(savedFeatureSetPath).getCanonicalPath();
                            final String currentCanonical = new File(outDirPrefix + this.currentFeatureSetFilenames[p]).getCanonicalPath();
                            if (!savedCanonical.equals(currentCanonical)) {
                                System.err.println("Warning: policy weights were saved for feature set " + savedFeatureSetPath + ", but we are now using " + this.currentFeatureSetFilenames[p]);
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    } else {
                        function = new LinearFunction(new FVector(featureSetArr[p].getNumFeatures()));
                        logLine(printWriter, "starting with new 0-weights linear function for Cross-Entropy");
                    }
                    functions[p] = function;
                }
                return functions;
            }

            /**
             * Builds the TSPG boosted linear functions, one per index in
             * {@code 1..count}.
             *
             * When {@code trainTSPG} is disabled the returned array contains only
             * {@code null} entries. Otherwise, without a "PolicyWeightsTSPG_P&lt;i&gt;"
             * checkpoint a boosted function is created over a new weight vector sized
             * to the feature set; with one, the weights are loaded from file and a
             * warning is printed to stderr if the feature set file they refer to differs
             * from the feature set file currently in use. In both cases the matching
             * entry of {@code linearFunctionArr} is passed along to the boosted function.
             *
             * @param featureSetArr     feature sets indexed like the returned array
             * @param linearFunctionArr functions handed to each boosted function
             * @return array of size {@code count + 1} (index 0 unused)
             */
            private LinearFunction[] prepareTSPGFunctions(FeatureSet[] featureSetArr, LinearFunction[] linearFunctionArr) {
                final LinearFunction[] functions = new LinearFunction[count + 1];
                if (!ExpertIteration.this.trainTSPG) {
                    return functions;
                }
                final String outDirPrefix = ExpertIteration.this.outDir.getAbsolutePath() + File.separator;
                for (int p = 1; p <= count; p++) {
                    final String prefix = "PolicyWeightsTSPG_P" + p;
                    this.currentPolicyWeightsTSPGFilenames[p] = getFilenameLastCheckpoint(prefix, "txt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentPolicyWeightsTSPGFilenames[p], prefix, "txt"));

                    final BoostedLinearFunction function;
                    if (this.currentPolicyWeightsTSPGFilenames[p] != null) {
                        final String weightsPath = outDirPrefix + this.currentPolicyWeightsTSPGFilenames[p];
                        function = BoostedLinearFunction.boostedFromFile(weightsPath, linearFunctionArr[p]);
                        logLine(printWriter, "continuing with Selection policy weights loaded from " + this.currentPolicyWeightsTSPGFilenames[p]);
                        try {
                            // Compare the feature set the weights were saved for against the one in use.
                            final String savedFeatureSetPath = new File(weightsPath).getParent() + File.separator + function.featureSetFile();
                            final String savedCanonical = new File(savedFeatureSetPath).getCanonicalPath();
                            final String currentCanonical = new File(outDirPrefix + this.currentFeatureSetFilenames[p]).getCanonicalPath();
                            if (!savedCanonical.equals(currentCanonical)) {
                                System.err.println("Warning: policy weights were saved for feature set " + savedFeatureSetPath + ", but we are now using " + this.currentFeatureSetFilenames[p]);
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    } else {
                        function = new BoostedLinearFunction(new FVector(featureSetArr[p].getNumFeatures()), linearFunctionArr[p]);
                        logLine(printWriter, "starting with new 0-weights linear function for TSPG");
                    }
                    functions[p] = function;
                }
                return functions;
            }

            /**
             * Builds the Cross-Entropy Exploration linear functions, one per index in
             * {@code 1..count} (index 0 of the returned array is left {@code null}).
             *
             * Without a "PolicyWeightsCEE_P&lt;i&gt;" checkpoint a function is created over
             * a new weight vector sized to the feature set; with one, the weights are
             * loaded from file and a warning is printed to stderr if the feature set
             * file they refer to differs from the feature set file currently in use.
             * {@code lastCheckpoint} is lowered to the checkpoint number extracted from
             * each filename.
             *
             * @param featureSetArr feature sets indexed like the returned array; used
             *                      to size new weight vectors
             * @return array of size {@code count + 1} holding the functions
             */
            private LinearFunction[] prepareCEExploreFunctions(FeatureSet[] featureSetArr) {
                final LinearFunction[] functions = new LinearFunction[count + 1];
                final String outDirPrefix = ExpertIteration.this.outDir.getAbsolutePath() + File.separator;
                for (int p = 1; p <= count; p++) {
                    final String prefix = "PolicyWeightsCEE_P" + p;
                    this.currentPolicyWeightsCEEFilenames[p] = getFilenameLastCheckpoint(prefix, "txt");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentPolicyWeightsCEEFilenames[p], prefix, "txt"));

                    final LinearFunction function;
                    if (this.currentPolicyWeightsCEEFilenames[p] != null) {
                        final String weightsPath = outDirPrefix + this.currentPolicyWeightsCEEFilenames[p];
                        function = LinearFunction.fromFile(weightsPath);
                        logLine(printWriter, "continuing with Selection policy weights loaded from " + this.currentPolicyWeightsCEEFilenames[p]);
                        try {
                            // Compare the feature set the weights were saved for against the one in use.
                            final String savedFeatureSetPath = new File(weightsPath).getParent() + File.separator + function.featureSetFile();
                            final String savedCanonical = new File(savedFeatureSetPath).getCanonicalPath();
                            final String currentCanonical = new File(outDirPrefix + this.currentFeatureSetFilenames[p]).getCanonicalPath();
                            if (!savedCanonical.equals(currentCanonical)) {
                                System.err.println("Warning: CE Exploration policy weights were saved for feature set " + savedFeatureSetPath + ", but we are now using " + this.currentFeatureSetFilenames[p]);
                            }
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    } else {
                        function = new LinearFunction(new FVector(featureSetArr[p].getNumFeatures()));
                        logLine(printWriter, "starting with new 0-weights linear function for Cross-Entropy Exploration");
                    }
                    functions[p] = function;
                }
                return functions;
            }

            /**
             * Prepares the heuristics object used as the value function.
             *
             * Sources are tried in this order: the most recent "ValueFunction"
             * checkpoint in the output directory; "BestHeuristics.txt" under
             * {@code bestAgentsDataDir} when that is set; otherwise the heuristics from
             * the game's own AI metadata. {@code lastCheckpoint} is lowered to the
             * checkpoint number extracted from the checkpoint filename.
             *
             * @return the initialised heuristics; {@code null} when value learning is
             *         disabled, or when reading a heuristics file failed with an
             *         {@link IOException} (which is only printed)
             */
            private Heuristics prepareValueFunction() {
                if (ExpertIteration.this.noValueLearning) {
                    return null;
                }
                this.currentValueFunctionFilename = getFilenameLastCheckpoint("ValueFunction", "txt");
                this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentValueFunctionFilename, "ValueFunction", "txt"));
                final Report compileReport = new Report();
                Heuristics valueFunction = null;

                if (this.currentValueFunctionFilename == null) {
                    if (ExpertIteration.this.bestAgentsDataDir == null) {
                        // Nothing saved anywhere: fall back to the heuristics shipped in the game metadata.
                        valueFunction = loadGameFromName.metadata().ai().heuristics();
                        valueFunction.init(loadGameFromName);
                        logLine(printWriter, "starting with new initial value function from .lud metadata");
                    } else {
                        try {
                            final String contents = FileHandling.loadTextContentsFromFile(ExpertIteration.this.bestAgentsDataDir + "/BestHeuristics.txt");
                            valueFunction = (Heuristics) Compiler.compileObject(contents, "metadata.ai.heuristics.Heuristics", compileReport);
                            valueFunction.init(loadGameFromName);
                        } catch (IOException e2) {
                            e2.printStackTrace();
                        }
                    }
                } else {
                    final String checkpointPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentValueFunctionFilename;
                    try {
                        final String contents = FileHandling.loadTextContentsFromFile(checkpointPath);
                        valueFunction = (Heuristics) Compiler.compileObject(contents, "metadata.ai.heuristics.Heuristics", compileReport);
                        valueFunction.init(loadGameFromName);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    logLine(printWriter, "continuing with value function from " + checkpointPath);
                }
                return valueFunction;
            }

            /**
             * Builds the feature set for every player (indices 1..count; index 0 is left null).
             *
             * <p>For each player the latest "FeatureSet_P&lt;p&gt;" checkpoint is looked up. If one
             * exists it is loaded from the output directory; otherwise a fresh atomic feature set
             * is generated. Freshly generated sets are then pruned: random games are played
             * (bounded by {@code numPruningGames} and {@code maxNumPruningSeconds}) while counting
             * how often each feature, and each pair of features, is active for the same move.
             * Two features that were each active at least {@code pruneInitFeaturesThreshold} times
             * and were always active together are treated as redundant, and one of them is dropped.
             *
             * <p>As side effects this updates {@code currentFeatureSetFilenames} and
             * {@code lastCheckpoint}, and sets {@code interrupted} when a feature set is empty.
             *
             * @return array of instantiated feature sets, indexed by player number
             */
            private FeatureSet[] prepareFeatureSets() {
                final FeatureSet[] featureSets = new FeatureSet[count + 1];
                // Players that start from a freshly generated atomic feature set; only these get pruned.
                final TIntArrayList playersToPrune = new TIntArrayList();
                for (int p = 1; p <= count; p++) {
                    this.currentFeatureSetFilenames[p] = getFilenameLastCheckpoint("FeatureSet_P" + p, "fs");
                    this.lastCheckpoint = Math.min(this.lastCheckpoint, extractCheckpointFromFilename(this.currentFeatureSetFilenames[p], "FeatureSet_P" + p, "fs"));
                    final FeatureSet featureSet;
                    if (this.currentFeatureSetFilenames[p] == null) {
                        featureSet = new FeatureSet(new AtomicFeatureGenerator(loadGameFromName, 2, 4).getFeatures());
                        playersToPrune.add(p);
                        logLine(printWriter, "starting with new initial feature set for Player " + p);
                        logLine(printWriter, "num atomic features = " + featureSet.getNumFeatures());
                    } else {
                        final String featureSetPath = ExpertIteration.this.outDir.getAbsolutePath() + File.separator + this.currentFeatureSetFilenames[p];
                        featureSet = new FeatureSet(featureSetPath);
                        logLine(printWriter, "continuing with feature set loaded from " + featureSetPath + " for Player " + p);
                    }
                    if (featureSet.getNumFeatures() == 0) {
                        System.err.println("ERROR: Feature Set has 0 features!");
                        logLine(printWriter, "Training with 0 features makes no sense, interrupting experiment.");
                        this.interrupted = true;
                    }
                    featureSet.instantiateFeatures(loadGameFromName, new int[]{p}, null);
                    featureSets[p] = featureSet;
                }

                if (playersToPrune.size() > 0) {
                    // coOccurrences[p][a][b] = number of moves for which features a and b of player p
                    // were both active; the diagonal [a][a] counts activations of a on its own.
                    // Matrices are only allocated for players that are actually pruned, because the
                    // entries of all other players are never read or written below.
                    final long[][][] coOccurrences = new long[count + 1][][];
                    for (int k = 0; k < playersToPrune.size(); k++) {
                        final int p = playersToPrune.getQuick(k);
                        final int numFeatures = featureSets[p].getNumFeatures();
                        coOccurrences[p] = new long[numFeatures][numFeatures];
                    }

                    final Context context = new Context(loadGameFromName, new Trial(loadGameFromName));
                    // 1000L forces long arithmetic so that a large seconds budget cannot overflow.
                    final long deadline = System.currentTimeMillis() + (ExpertIteration.this.maxNumPruningSeconds * 1000L);
                    for (int gameIdx = 0; gameIdx < ExpertIteration.this.numPruningGames && System.currentTimeMillis() <= deadline; gameIdx++) {
                        loadGameFromName.start(context);
                        while (!context.trial().over()) {
                            final FastArrayList<Move> moves = loadGameFromName.moves(context).moves();
                            final int mover = context.state().mover();
                            if (playersToPrune.contains(mover)) {
                                final long[][] counts = coOccurrences[mover];
                                for (TIntArrayList activeFeatures : featureSets[mover].computeSparseFeatureVectors(context, moves, false)) {
                                    for (int a = 0; a < activeFeatures.size(); a++) {
                                        final int first = activeFeatures.getQuick(a);
                                        counts[first][first]++;
                                        for (int b = a + 1; b < activeFeatures.size(); b++) {
                                            final int second = activeFeatures.getQuick(b);
                                            counts[first][second]++;
                                            counts[second][first]++;
                                        }
                                    }
                                }
                            }
                            // Play a uniformly random legal move.
                            loadGameFromName.apply(context, moves.get(ThreadLocalRandom.current().nextInt(moves.size())));
                        }
                    }

                    for (int k = 0; k < playersToPrune.size(); k++) {
                        final int p = playersToPrune.getQuick(k);
                        final long[][] counts = coOccurrences[p];
                        final TIntHashSet prunedIndices = new TIntHashSet();
                        final FeatureSet unpruned = featureSets[p];
                        final int numFeatures = unpruned.getNumFeatures();
                        for (int a = 0; a < numFeatures; a++) {
                            if (prunedIndices.contains(a)) {
                                continue;
                            }
                            final long activations = counts[a][a];
                            if (activations < ExpertIteration.this.pruneInitFeaturesThreshold) {
                                continue;
                            }
                            for (int b = a + 1; b < numFeatures; b++) {
                                // a and b are redundant when neither was ever active without the other.
                                if (!prunedIndices.contains(b) && activations == counts[a][b] && activations == counts[b][b]) {
                                    final Pattern patternA = unpruned.features()[a].pattern();
                                    final Pattern patternB = unpruned.features()[b].pattern();
                                    // Keep a unless b has fewer elements or fewer total walk steps.
                                    final boolean keepFirst =
                                            patternB.featureElements().size() >= patternA.featureElements().size()
                                            && totalWalkSteps(patternB) >= totalWalkSteps(patternA);
                                    prunedIndices.add(keepFirst ? b : a);
                                }
                            }
                        }

                        final ArrayList<Feature> keptFeatures = new ArrayList<>();
                        for (int idx = 0; idx < numFeatures; idx++) {
                            if (!prunedIndices.contains(idx)) {
                                keptFeatures.add(unpruned.features()[idx]);
                            }
                        }
                        final FeatureSet prunedSet = new FeatureSet(keptFeatures);
                        prunedSet.instantiateFeatures(loadGameFromName, new int[]{p}, null);
                        featureSets[p] = prunedSet;
                        logLine(printWriter, "Finished pruning atomic feature set for Player " + p);
                        logLine(printWriter, "Num atomic features after pruning = " + prunedSet.getNumFeatures());
                    }
                }
                return featureSets;
            }

            /** Sums the number of walk steps over all relative elements of the given pattern. */
            private int totalWalkSteps(Pattern pattern) {
                int steps = 0;
                for (FeatureElement element : pattern.featureElements()) {
                    if (element instanceof RelativeFeatureElement) {
                        steps += ((RelativeFeatureElement) element).walk().steps().size();
                    }
                }
                return steps;
            }

            /**
             * Tries to grow the given feature set by one feature, built by combining two feature
             * instances that were observed active together in the given batch of experience.
             *
             * <p>Every (experience, legal move) combination counts as one case. For each case the
             * policy's distribution error for that move is recorded, together with which feature
             * instances (and which pairs of instances) were active. Each candidate pair is scored
             * by the absolute correlation between "pair active" and the error, multiplied by one
             * minus the larger absolute correlation between the pair and either of its
             * constituents. Candidates are tried from highest score downwards until
             * {@code createExpandedFeatureSet} yields a non-null feature set.
             *
             * @param exItExperienceArr batch of experience samples to analyse
             * @param featureSet feature set to expand
             * @param softmaxPolicy policy whose distribution errors and weights are used
             * @param game2 game passed on when building pairs and the expanded feature set
             * @param i maximum number of active feature instances considered per move; when more
             *     are active, those with the largest absolute weight are kept
             * @return the expanded feature set, or null if no expansion could be made
             */
            public FeatureSet expandFeatureSetCorrelationBased(ExItExperience[] exItExperienceArr, FeatureSet featureSet, final SoftmaxPolicy softmaxPolicy, Game game2, int i) {
                int numCases = 0;
                // Missing keys read as 0 / 0.0 (the no-entry values passed to the constructors).
                final TObjectIntHashMap<CombinableFeatureInstancePair> pairActivations = new TObjectIntHashMap<>(10, 0.5f, 0);
                final TObjectDoubleHashMap<CombinableFeatureInstancePair> pairErrorSums = new TObjectDoubleHashMap<>(10, 0.5f, 0.0d);
                double sumErrors = 0.0d;
                double sumSquaredErrors = 0.0d;

                // Features already in the set; built once because the feature set is not modified here.
                final HashSet<Feature> existingFeatures = new HashSet<>((int) Math.ceil(featureSet.getNumFeatures() / 0.75f), 0.75f);
                for (Feature feature : featureSet.features()) {
                    existingFeatures.add(feature);
                }

                for (final ExItExperience exItExperience : exItExperienceArr) {
                    final FVector errors = softmaxPolicy.computeDistributionErrors(softmaxPolicy.computeDistribution(featureSet.computeSparseFeatureVectors(exItExperience.state().state(), exItExperience.state().lastDecisionMove(), exItExperience.moves(), false), exItExperience.state().state().mover()), exItExperience.expertDistribution());
                    for (int moveIdx = 0; moveIdx < exItExperience.moves().size(); moveIdx++) {
                        numCases++;
                        // Pairs already counted for this case, so each pair is counted at most once per case.
                        final HashSet<CombinableFeatureInstancePair> observedThisCase = new HashSet<>(256, 0.75f);
                        List<FeatureInstance> activeInstances = featureSet.getActiveFeatureInstances(exItExperience.state().state(), FeatureUtils.fromPos(exItExperience.state().lastDecisionMove()), FeatureUtils.toPos(exItExperience.state().lastDecisionMove()), FeatureUtils.fromPos(exItExperience.moves().get(moveIdx)), FeatureUtils.toPos(exItExperience.moves().get(moveIdx)), exItExperience.moves().get(moveIdx).mover());
                        if (activeInstances.size() > i) {
                            // Too many active instances: keep only those with the largest absolute weights.
                            activeInstances.sort(new Comparator<FeatureInstance>() {
                                @Override
                                public int compare(FeatureInstance first, FeatureInstance second) {
                                    final int firstIdx = first.feature().featureSetIndex();
                                    final int secondIdx = second.feature().featureSetIndex();
                                    final float firstWeight = Math.abs(softmaxPolicy.linearFunction(exItExperience.state().state().mover()).effectiveParams().get(firstIdx));
                                    final float secondWeight = Math.abs(softmaxPolicy.linearFunction(exItExperience.state().state().mover()).effectiveParams().get(secondIdx));
                                    if (firstWeight == secondWeight) {
                                        return 0;
                                    }
                                    return firstWeight > secondWeight ? -1 : 1;
                                }
                            });
                            activeInstances = activeInstances.subList(0, i);
                        }
                        final int numActive = activeInstances.size();
                        final float error = errors.get(moveIdx);
                        sumErrors += error;
                        sumSquaredErrors += error * error;
                        for (int a = 0; a < numActive; a++) {
                            final FeatureInstance instanceA = activeInstances.get(a);
                            // A pair of an instance with itself tracks how often that single instance was active.
                            final CombinableFeatureInstancePair selfPair = new CombinableFeatureInstancePair(game2, instanceA, instanceA);
                            if (!observedThisCase.contains(selfPair)) {
                                pairActivations.put(selfPair, pairActivations.get(selfPair) + 1);
                                pairErrorSums.put(selfPair, pairErrorSums.get(selfPair) + error);
                                observedThisCase.add(selfPair);
                            }
                            for (int b = a + 1; b < numActive; b++) {
                                final CombinableFeatureInstancePair pair = new CombinableFeatureInstancePair(game2, instanceA, activeInstances.get(b));
                                if (!existingFeatures.contains(pair.combinedFeature) && !observedThisCase.contains(pair)) {
                                    pairActivations.put(pair, pairActivations.get(pair) + 1);
                                    pairErrorSums.put(pair, pairErrorSums.get(pair) + error);
                                    observedThisCase.add(pair);
                                }
                            }
                        }
                    }
                }

                if (sumErrors == 0.0d || sumSquaredErrors == 0.0d) {
                    return null;
                }

                final ArrayList<ScoredPair> candidates = new ArrayList<>(pairActivations.size());
                for (CombinableFeatureInstancePair pair : pairActivations.keySet()) {
                    if (pair.a.equals(pair.b)) {
                        continue;
                    }
                    final int actsA = pairActivations.get(new CombinableFeatureInstancePair(game2, pair.a, pair.a));
                    final int actsB = pairActivations.get(new CombinableFeatureInstancePair(game2, pair.b, pair.b));
                    final int pairActs = pairActivations.get(pair);
                    // Something that is active in every single case carries no information.
                    if (pairActs == numCases || actsA == numCases || actsB == numCases) {
                        continue;
                    }
                    final double errorCorr = errorCorrelation(numCases, pairActs, pairErrorSums.get(pair), sumErrors, sumSquaredErrors);
                    final double corrWithA = constituentCorrelation(numCases, pairActs, actsA);
                    final double corrWithB = constituentCorrelation(numCases, pairActs, actsB);
                    final double score = Math.abs(errorCorr) * (1.0d - Math.max(Math.abs(corrWithA), Math.abs(corrWithB)));
                    if (!Double.isNaN(score)) {
                        candidates.add(new ScoredPair(pair, score));
                    }
                }

                while (!candidates.isEmpty()) {
                    final int bestIdx = indexOfHighestScore(candidates);
                    if (bestIdx < 0) {
                        // No remaining score exceeds negative infinity; removing index -1 would throw.
                        break;
                    }
                    final ScoredPair best = candidates.remove(bestIdx);
                    final FeatureSet expanded = featureSet.createExpandedFeatureSet(game2, best.pair.a, best.pair.b);
                    if (expanded == null) {
                        // This combination could not be added; try the next-best candidate.
                        continue;
                    }
                    final CombinableFeatureInstancePair bestPair = new CombinableFeatureInstancePair(game2, best.pair.a, best.pair.b);
                    final int actsA = pairActivations.get(new CombinableFeatureInstancePair(game2, best.pair.a, best.pair.a));
                    final int actsB = pairActivations.get(new CombinableFeatureInstancePair(game2, best.pair.b, best.pair.b));
                    final int pairActs = pairActivations.get(bestPair);
                    final double errorCorr = errorCorrelation(numCases, pairActs, pairErrorSums.get(bestPair), sumErrors, sumSquaredErrors);
                    final double corrWithA = constituentCorrelation(numCases, pairActs, actsA);
                    final double corrWithB = constituentCorrelation(numCases, pairActs, actsB);
                    logLine(printWriter, "New feature added!");
                    logLine(printWriter, "new feature = " + expanded.features()[expanded.getNumFeatures() - 1]);
                    logLine(printWriter, "active feature A = " + best.pair.a.feature());
                    logLine(printWriter, "rot A = " + best.pair.a.rotation());
                    logLine(printWriter, "ref A = " + best.pair.a.reflection());
                    logLine(printWriter, "anchor A = " + best.pair.a.anchorSite());
                    logLine(printWriter, "active feature B = " + best.pair.b.feature());
                    logLine(printWriter, "rot B = " + best.pair.b.rotation());
                    logLine(printWriter, "ref B = " + best.pair.b.reflection());
                    logLine(printWriter, "anchor B = " + best.pair.b.anchorSite());
                    logLine(printWriter, "score = " + best.score);
                    logLine(printWriter, "correlation with errors = " + errorCorr);
                    logLine(printWriter, "correlation with first constituent = " + corrWithA);
                    logLine(printWriter, "correlation with second constituent = " + corrWithB);
                    logLine(printWriter, "observed pair of instances " + pairActs + " times");
                    logLine(printWriter, "observed first constituent " + actsA + " times");
                    logLine(printWriter, "observed second constituent " + actsB + " times");
                    return expanded;
                }
                return null;
            }

            /**
             * Correlation between "pair is active" and the observed errors. The count product is
             * computed in double precision so that large case counts cannot overflow int.
             */
            private double errorCorrelation(int numCases, int pairActs, double pairErrorSum, double sumErrors, double sumSquaredErrors) {
                final double numerator = (numCases * pairErrorSum) - (pairActs * sumErrors);
                final double pairSpread = Math.sqrt((double) pairActs * (numCases - pairActs));
                final double errorSpread = Math.sqrt((numCases * sumSquaredErrors) - (sumErrors * sumErrors));
                return numerator / (pairSpread * errorSpread);
            }

            /**
             * Correlation between "pair is active" and "one of its constituents is active". All
             * count products are computed in double precision so that they cannot overflow int.
             */
            private double constituentCorrelation(int numCases, int pairActs, int constituentActs) {
                final double numerator = (double) pairActs * (numCases - constituentActs);
                final double pairSpread = Math.sqrt((double) pairActs * (numCases - pairActs));
                final double constituentSpread = Math.sqrt((double) constituentActs * (numCases - constituentActs));
                return numerator / (pairSpread * constituentSpread);
            }

            /**
             * Returns the index of the first candidate with the highest score, or -1 if no
             * candidate has a score greater than negative infinity.
             */
            private int indexOfHighestScore(List<ScoredPair> candidates) {
                double bestScore = Double.NEGATIVE_INFINITY;
                int bestIdx = -1;
                for (int idx = 0; idx < candidates.size(); idx++) {
                    final double score = candidates.get(idx).score;
                    if (score > bestScore) {
                        bestScore = score;
                        bestIdx = idx;
                    }
                }
                return bestIdx;
            }

            /**
             * Returns the checkpoint number at which the next save is due: 0 when the last
             * checkpoint is negative, otherwise the last checkpoint plus the checkpoint frequency.
             */
            private long computeNextCheckpoint() {
                return (this.lastCheckpoint < 0)
                        ? 0L
                        : this.lastCheckpoint + ExpertIteration.this.checkpointFrequency;
            }

            /**
             * Formats a checkpoint filename from a base name, a checkpoint number and an extension,
             * using the game-based format when checkpoints are counted in games and the
             * weight-update format otherwise.
             */
            private String createCheckpointFilename(String str, long j, String str2) {
                final String format;
                if (ExpertIteration.this.checkpointType == CheckpointTypes.Game) {
                    format = ExpertIteration.gameCheckpointFormat;
                } else {
                    format = ExpertIteration.weightUpdateCheckpointFormat;
                }
                return String.format(format, str, Long.valueOf(j), str2);
            }

            /**
             * Parses the checkpoint number out of a filename of the form
             * {@code <str2>_<number>.<str3>}.
             *
             * @param str filename to parse; may be null
             * @param str2 base name that precedes the underscore
             * @param str3 file extension, without the dot
             * @return the parsed checkpoint number, or -1 if {@code str} is null
             * @throws NumberFormatException if the text between prefix and extension is not an int
             */
            private int extractCheckpointFromFilename(String str, String str2, String str3) {
                if (str != null) {
                    final String prefix = str2 + "_";
                    final String suffix = "." + str3;
                    final String digits = str.substring(prefix.length(), str.length() - suffix.length());
                    return Integer.parseInt(digits);
                }
                return -1;
            }

            /**
             * Finds the file in the output directory with the highest checkpoint number among
             * those named {@code <str>_<number>.<str2>}.
             *
             * @param str base name of the checkpoint files
             * @param str2 file extension, without the dot
             * @return filename (not a full path) of the latest checkpoint, or null when there is
             *     no output directory, it cannot be listed, or it has no matching file
             */
            private String getFilenameLastCheckpoint(String str, String str2) {
                if (ExpertIteration.this.outDir == null) {
                    return null;
                }
                // File.list() returns null when the path is not a directory or an I/O error occurs;
                // treat that the same as "no checkpoints" instead of failing with a NullPointerException.
                final String[] filenames = ExpertIteration.this.outDir.list();
                if (filenames == null) {
                    return null;
                }
                final String prefix = str + "_";
                final String suffix = "." + str2;
                int latestCheckpoint = -1;
                for (String filename : filenames) {
                    if (filename.startsWith(prefix) && filename.endsWith(suffix)) {
                        final int checkpoint = extractCheckpointFromFilename(filename, str, str2);
                        if (checkpoint > latestCheckpoint) {
                            latestCheckpoint = checkpoint;
                        }
                    }
                }
                if (latestCheckpoint < 0) {
                    return null;
                }
                return createCheckpointFilename(str, latestCheckpoint, str2);
            }

            /**
             * Writes a checkpoint of the current training state to the output directory, if one is
             * due or if forced.
             *
             * <p>Depending on the checkpoint type, either {@code i} or {@code j} is compared with
             * the next due checkpoint; unless {@code z} is set, nothing is written while that value
             * is still below it. Feature sets and policy weights are written for every player, and
             * the value function once if present. Experience buffers, optimisers and game duration
             * trackers are written only when {@code z} is set. The remembered "current" filenames
             * and {@code lastCheckpoint} are updated afterwards.
             *
             * @param i counter used when the checkpoint type is Game
             * @param j counter used when the checkpoint type is WeightUpdate
             * @param featureSetArr per-player feature sets (index 0 unused)
             * @param linearFunctionArr per-player functions saved as "PolicyWeightsCE"
             * @param linearFunctionArr2 per-player functions saved as "PolicyWeightsTSPG" (only if trainTSPG)
             * @param linearFunctionArr3 per-player functions saved as "PolicyWeightsCEE" (only if
             *     ceExplore and not ceExploreUniform)
             * @param heuristics value function to save, or null to skip it
             * @param experienceBufferArr per-player experience buffers
             * @param optimiserArr per-player optimisers saved as "OptimiserCE"
             * @param optimiserArr2 per-player optimisers saved as "OptimiserTSPG"
             * @param optimiserArr3 per-player optimisers saved as "OptimiserCEE"
             * @param optimiser optimiser saved as "OptimiserValue"
             * @param exponentialMovingAverageArr per-player trackers saved as "GameDurationTracker"
             * @param z if true, save regardless of whether a checkpoint is due, and also save the
             *     buffers, optimisers and trackers
             */
            private void saveCheckpoints(int i, long j, FeatureSet[] featureSetArr, LinearFunction[] linearFunctionArr, LinearFunction[] linearFunctionArr2, LinearFunction[] linearFunctionArr3, Heuristics heuristics, ExperienceBuffer[] experienceBufferArr, Optimiser[] optimiserArr, Optimiser[] optimiserArr2, Optimiser[] optimiserArr3, Optimiser optimiser, ExponentialMovingAverage[] exponentialMovingAverageArr, boolean z) {
                long checkpoint = computeNextCheckpoint();
                if (ExpertIteration.this.checkpointType == CheckpointTypes.Game) {
                    if (!z && i < checkpoint) {
                        return;
                    }
                    checkpoint = i;
                } else if (ExpertIteration.this.checkpointType == CheckpointTypes.WeightUpdate) {
                    if (!z && j < checkpoint) {
                        return;
                    }
                    checkpoint = j;
                }

                final String outDirPrefix = ExpertIteration.this.outDir.getAbsolutePath() + File.separator;
                final boolean saveExplorationPolicy = ExpertIteration.this.ceExplore && !ExpertIteration.this.ceExploreUniform;

                // The value function does not depend on the player, so write it once per checkpoint
                // instead of rewriting the same file for every player.
                if (heuristics != null) {
                    heuristics.toFile(loadGameFromName, outDirPrefix + createCheckpointFilename("ValueFunction", checkpoint, "txt"));
                }

                for (int p = 1; p <= count; p++) {
                    final String featureSetFilename = createCheckpointFilename("FeatureSet_P" + p, checkpoint, "fs");
                    featureSetArr[p].toFile(outDirPrefix + featureSetFilename);
                    this.currentFeatureSetFilenames[p] = featureSetFilename;

                    // Each weights file is written together with the name of the feature set it belongs to.
                    final String ceWeightsFilename = createCheckpointFilename("PolicyWeightsCE_P" + p, checkpoint, "txt");
                    linearFunctionArr[p].writeToFile(outDirPrefix + ceWeightsFilename, new String[]{this.currentFeatureSetFilenames[p]});
                    this.currentPolicyWeightsCEFilenames[p] = ceWeightsFilename;

                    if (ExpertIteration.this.trainTSPG) {
                        final String tspgWeightsFilename = createCheckpointFilename("PolicyWeightsTSPG_P" + p, checkpoint, "txt");
                        linearFunctionArr2[p].writeToFile(outDirPrefix + tspgWeightsFilename, new String[]{this.currentFeatureSetFilenames[p]});
                        this.currentPolicyWeightsTSPGFilenames[p] = tspgWeightsFilename;
                    }
                    if (saveExplorationPolicy) {
                        final String ceeWeightsFilename = createCheckpointFilename("PolicyWeightsCEE_P" + p, checkpoint, "txt");
                        linearFunctionArr3[p].writeToFile(outDirPrefix + ceeWeightsFilename, new String[]{this.currentFeatureSetFilenames[p]});
                        this.currentPolicyWeightsCEEFilenames[p] = ceeWeightsFilename;
                    }

                    if (z) {
                        experienceBufferArr[p].writeToFile(outDirPrefix + createCheckpointFilename("ExperienceBuffer_P" + p, checkpoint, "buf"));

                        final String ceOptimiserFilename = createCheckpointFilename("OptimiserCE_P" + p, checkpoint, "opt");
                        optimiserArr[p].writeToFile(outDirPrefix + ceOptimiserFilename);
                        this.currentOptimiserCEFilenames[p] = ceOptimiserFilename;

                        if (ExpertIteration.this.trainTSPG) {
                            final String tspgOptimiserFilename = createCheckpointFilename("OptimiserTSPG_P" + p, checkpoint, "opt");
                            optimiserArr2[p].writeToFile(outDirPrefix + tspgOptimiserFilename);
                            this.currentOptimiserTSPGFilenames[p] = tspgOptimiserFilename;
                        }
                        if (saveExplorationPolicy) {
                            final String ceeOptimiserFilename = createCheckpointFilename("OptimiserCEE_P" + p, checkpoint, "opt");
                            optimiserArr3[p].writeToFile(outDirPrefix + ceeOptimiserFilename);
                            this.currentOptimiserCEEFilenames[p] = ceeOptimiserFilename;
                        }

                        final String durationTrackerFilename = createCheckpointFilename("GameDurationTracker_P" + p, checkpoint, "bin");
                        exponentialMovingAverageArr[p].writeToFile(outDirPrefix + durationTrackerFilename);
                        this.currentGameDurationTrackerFilenames[p] = durationTrackerFilename;
                    }
                }

                if (z) {
                    final String valueOptimiserFilename = createCheckpointFilename("OptimiserValue", checkpoint, "opt");
                    optimiser.writeToFile(outDirPrefix + valueOptimiserFilename);
                    this.currentOptimiserValueFilename = valueOptimiserFilename;
                }
                this.lastCheckpoint = checkpoint;
            }

            /**
             * Forwards the line to the superclass logger, unless logging has been disabled for
             * this experiment.
             */
            @Override // utils.experiments.InterruptableExperiment
            public void logLine(PrintWriter printWriter2, String str) {
                if (!ExpertIteration.this.noLogging) {
                    super.logLine(printWriter2, str);
                }
            }
        };
    }

    /**
     * Opens a UTF-8 writer for the next free "ExIt" log file in the output directory, creating
     * parent directories as needed.
     *
     * @return the log writer, or null when there is no output directory, logging is disabled,
     *     or the file could not be opened
     */
    private PrintWriter createLogWriter() {
        final boolean loggingEnabled = this.outDir != null && !this.noLogging;
        if (!loggingEnabled) {
            return null;
        }
        final String logPath = ExperimentFileUtils.getNextFilepath(this.outDir.getAbsolutePath() + File.separator + "ExIt", "log");
        final File logFile = new File(logPath);
        logFile.getParentFile().mkdirs();
        PrintWriter writer = null;
        try {
            writer = new PrintWriter(logPath, "UTF-8");
        } catch (FileNotFoundException | UnsupportedEncodingException e) {
            e.printStackTrace();
        }
        return writer;
    }

    public static void main(String[] strArr) {
        CommandLineArgParse commandLineArgParse = new CommandLineArgParse(true, "Execute a training run from self-play using Expert Iteration.");
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game").help("Name of the game to play. Should end with \".lud\".").withDefault("board/space/blocking/Amazons.lud").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game-options").help("Game Options to load.").withDefault(new ArrayList(0)).withNumVals("*").withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--expert-ai").help("Type of AI to use as expert.").withDefault("BEST_AGENT").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).withLegalVals("BEST_AGENT", "FROM_METADATA", "Biased MCTS"));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--best-agents-data-dir").help("Filepath for directory with best agents data for this game (+ options).").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("-n", "--num-games", "--num-training-games").help("Number of training games to run.").withDefault(200).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--game-length-cap", "--max-num-actions").help("Maximum number of actions that may be taken before a game is terminated as a draw (-1 for no limit).").withDefault(-1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--thinking-time", "--time", "--seconds").help("Max allowed thinking time per move (in seconds).").withDefault(Double.valueOf(1.0d)).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Double));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--iteration-limit", "--iterations").help("Max allowed number of MCTS iterations per move.").withDefault(-1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--depth-limit").help("Search depth limit (e.g. for Alpha-Beta experts).").withDefault(-1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--add-feature-every").help("After this many training games, we add a new feature.").withDefault(1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--batch-size").help("Max size of minibatches in training.").withDefault(30).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--buffer-size", "--experience-buffer-size").help("Max size of the experience buffer.").withDefault(2500).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--update-weights-every").help("After this many moves (decision points) in training games, we update weights.").withDefault(1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--no-grow-features", "--no-grow-featureset", "--no-grow-feature-set").help("If true, we'll not grow feature set (but still train weights).").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--train-tspg").help("If true, we'll train a policy on TSPG objective (see COG paper).").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ce-optimiser", "--cross-entropy-optimiser").help("Optimiser to use for policy trained on Cross-Entropy loss.").withDefault("RMSProp").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--cee-optimiser", "--cross-entropy-exploration-optimiser").help("Optimiser to use for training Cross-Entropy Exploration policy.").withDefault("RMSProp").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--tspg-optimiser").help("Optimiser to use for policy trained on TSPG objective (see COG paper).").withDefault("RMSProp").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--value-optimiser").help("Optimiser to use for value function optimisation.").withDefault("RMSProp").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--combining-feature-instance-threshold").help("At most this number of feature instances will be taken into account when combining features.").withDefault(75).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--is-episode-durations").help("If true, we'll use importance sampling weights based on episode durations for CE-loss.").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--prioritized-experience-replay", "--per").help("If true, we'll use prioritized experience replay").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ce-explore").help("If true, we'll use extra exploration based on cross-entropy losses").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ce-explore-mix").help("Proportion of exploration policy in our behaviour mix").withDefault(Float.valueOf(0.1f)).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Float));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ce-explore-gamma").help("Discount factor gamma for rewards awarded to CE Explore policy").withDefault(Double.valueOf(0.99d)).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Double));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--ce-explore-uniform").help("If true, our CE Explore policy will not be trained, but remain completely uniform").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--no-ce-explore-is").help("If true, we ignore importance sampling when doing CE Exploration").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--wis", "--weighted-importance-sampling").help("If true, we use Weighted Importance Sampling instead of Ordinary Importance Sampling for any of the above").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--no-value-learning").help("If true, we don't do any value function learning.").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--mcts-as-reg-pol-opt").help("If true, we use Act, Search, and Learn as described in the MCTS as regularized policy optimization paper for Biased MCTS.").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--max-biased-playout-actions", "--max-num-biased-playout-actions").help("Maximum number of actions per playout which we'll bias using features (-1 for no limit).").withDefault(-1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--no-prune-init-features").help("If true, we will keep full atomic feature set and not prune anything.").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--prune-init-features-threshold").help("Will only consider pruning features if they have been active at least this many times.").withDefault(50).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--num-pruning-games").help("Number of random games to play out for determining features to prune.").withDefault(1500).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--max-pruning-seconds").help("Maximum number of seconds to spend on random games for pruning initial featureset.").withDefault(60).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--checkpoint-type", "--checkpoints").help("When do we store checkpoints of trained weights?").withDefault(CheckpointTypes.Game.toString()).withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).withLegalVals(Arrays.stream(CheckpointTypes.values()).map((v0) -> {
            return v0.toString();
        }).toArray()));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--checkpoint-freq", "--checkpoint-frequency").help("Frequency of checkpoint updates").withDefault(1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--out-dir", "--output-directory").help("Filepath for output directory").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--no-logging").help("If true, we don't write a bunch of messages to a log file.").withType(CommandLineArgParse.OptionTypes.Boolean));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--useGUI").help("Whether to create a small GUI that can be used to manually interrupt training run. False by default."));
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--max-wall-time").help("Max wall time in minutes (or -1 for no limit).").withDefault(-1).withNumVals(1).withType(CommandLineArgParse.OptionTypes.Int));
        if (commandLineArgParse.parseArguments(strArr)) {
            ExpertIteration expertIteration = new ExpertIteration(commandLineArgParse.getValueBool("--useGUI"), commandLineArgParse.getValueInt("--max-wall-time"));
            expertIteration.gameName = commandLineArgParse.getValueString("--game");
            expertIteration.gameOptions = (List) commandLineArgParse.getValue("--game-options");
            expertIteration.expertAI = commandLineArgParse.getValueString("--expert-ai");
            expertIteration.bestAgentsDataDir = commandLineArgParse.getValueString("--best-agents-data-dir");
            expertIteration.numTrainingGames = commandLineArgParse.getValueInt("-n");
            expertIteration.gameLengthCap = commandLineArgParse.getValueInt("--game-length-cap");
            expertIteration.thinkingTime = commandLineArgParse.getValueDouble("--thinking-time");
            expertIteration.iterationLimit = commandLineArgParse.getValueInt("--iteration-limit");
            expertIteration.depthLimit = commandLineArgParse.getValueInt("--depth-limit");
            expertIteration.addFeatureEvery = commandLineArgParse.getValueInt("--add-feature-every");
            expertIteration.batchSize = commandLineArgParse.getValueInt("--batch-size");
            expertIteration.experienceBufferSize = commandLineArgParse.getValueInt("--buffer-size");
            expertIteration.updateWeightsEvery = commandLineArgParse.getValueInt("--update-weights-every");
            expertIteration.noGrowFeatureSet = commandLineArgParse.getValueBool("--no-grow-features");
            expertIteration.trainTSPG = commandLineArgParse.getValueBool("--train-tspg");
            expertIteration.crossEntropyOptimiserConfig = commandLineArgParse.getValueString("--ce-optimiser");
            expertIteration.ceExploreOptimiserConfig = commandLineArgParse.getValueString("--cee-optimiser");
            expertIteration.tspgOptimiserConfig = commandLineArgParse.getValueString("--tspg-optimiser");
            expertIteration.valueOptimiserConfig = commandLineArgParse.getValueString("--value-optimiser");
            expertIteration.combiningFeatureInstanceThreshold = commandLineArgParse.getValueInt("--combining-feature-instance-threshold");
            expertIteration.importanceSamplingEpisodeDurations = commandLineArgParse.getValueBool("--is-episode-durations");
            expertIteration.prioritizedExperienceReplay = commandLineArgParse.getValueBool("--prioritized-experience-replay");
            expertIteration.ceExplore = commandLineArgParse.getValueBool("--ce-explore");
            expertIteration.ceExploreMix = commandLineArgParse.getValueFloat("--ce-explore-mix");
            expertIteration.ceExploreGamma = commandLineArgParse.getValueDouble("--ce-explore-gamma");
            expertIteration.ceExploreUniform = commandLineArgParse.getValueBool("--ce-explore-uniform");
            expertIteration.noCEExploreIS = commandLineArgParse.getValueBool("--no-ce-explore-is");
            expertIteration.weightedImportanceSampling = commandLineArgParse.getValueBool("--wis");
            expertIteration.noValueLearning = commandLineArgParse.getValueBool("--no-value-learning");
            expertIteration.mctsRegPolOpt = commandLineArgParse.getValueBool("--mcts-as-reg-pol-opt");
            expertIteration.maxNumBiasedPlayoutActions = commandLineArgParse.getValueInt("--max-num-biased-playout-actions");
            expertIteration.noPruneInitFeatures = commandLineArgParse.getValueBool("--no-prune-init-features");
            expertIteration.pruneInitFeaturesThreshold = commandLineArgParse.getValueInt("--prune-init-features-threshold");
            expertIteration.numPruningGames = commandLineArgParse.getValueInt("--num-pruning-games");
            expertIteration.maxNumPruningSeconds = commandLineArgParse.getValueInt("--max-pruning-seconds");
            expertIteration.checkpointType = CheckpointTypes.valueOf(commandLineArgParse.getValueString("--checkpoint-type"));
            expertIteration.checkpointFrequency = commandLineArgParse.getValueInt("--checkpoint-freq");
            String valueString = commandLineArgParse.getValueString("--out-dir");
            if (valueString != null) {
                expertIteration.outDir = new File(valueString);
            } else {
                expertIteration.outDir = null;
            }
            expertIteration.noLogging = commandLineArgParse.getValueBool("--no-logging");
            expertIteration.startExperiment();
        }
    }
}
