package supplementary.experiments.eval;

import compiler.Compiler;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import game.Game;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.regex.Pattern;
import java.util.stream.IntStream;
import main.CommandLineArgParse;
import main.FileHandling;
import main.StringRoutines;
import main.collections.ListUtils;
import main.grammar.Report;
import metadata.ai.agents.BestAgent;
import metadata.ai.features.Features;
import metadata.ai.heuristics.Heuristics;
import org.apache.batik.constants.XMLConstants;
import other.AI;
import other.GameLoader;
import other.RankUtils;
import other.context.Context;
import other.model.Model;
import other.trial.Trial;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import search.minimax.AlphaBetaSearch;
import utils.AIFactory;
import utils.AIUtils;
import utils.experiments.InterruptableExperiment;
import utils.experiments.ResultsSummary;

/* loaded from: input_file:supplementary/experiments/eval/EvalGate.class */
public class EvalGate {
    /** Friendly name given to every instance of the agent being evaluated. */
    private static final String EVAL_AI_NAME = "EvalAI";
    /** Friendly name given to every instance of the gate (current best) agent. */
    private static final String GATE_AI_NAME = "GateAI";

    /** Name of the game to play. */
    protected String gameName;
    /** Game options to load; only used when no ruleset is specified. */
    protected List<String> gameOptions;
    /** Ruleset to compile; {@code null} or empty means the game options are used instead. */
    protected String ruleset;
    /** Requested number of evaluation games; rounded up to a multiple of the number of seat permutations. */
    protected int numGames;
    /** Maximum number of turns per game; negative means no cap. */
    protected int gameLengthCap;
    /** Maximum thinking time per move, in seconds. */
    protected double thinkingTime;
    /** Number of seconds spent running random playouts to warm up the JVM. */
    protected int warmingUpSecs;
    /** Agent type being evaluated: "Alpha-Beta", "BiasedMCTS" or "BiasedMCTSUniformPlayouts". */
    protected String evalAgent;
    /** Feature-weight files for the evaluated agent (Biased MCTS variants); may be {@code null}. */
    protected List<String> evalFeatureWeightsFilepaths;
    /** Heuristics file for the evaluated agent (Alpha-Beta). */
    protected String evalHeuristicsFilepath;
    /** Directory holding BestAgent.txt, BestHeuristics.txt and BestFeatures.txt. */
    protected File bestAgentsDataDir;
    /** Type of gate agent: "BestAgent", "Alpha-Beta" or "BiasedMCTS". */
    protected String gateAgentType;
    /** Whether to show a small GUI that can be used to manually interrupt the experiment. */
    protected boolean useGUI;
    /** Max wall time in minutes, or -1 for no limit. */
    protected int maxWallTime;

    private EvalGate(final boolean useGUI, final int maxWallTime) {
        this.useGUI = useGUI;
        this.maxWallTime = maxWallTime;
    }

    /**
     * Returns the absolute path of the best-agents data directory with all backslashes
     * replaced by forward slashes.
     */
    private String bestAgentsDataDirPath() {
        return bestAgentsDataDir.getAbsolutePath().replace('\\', '/');
    }

    /**
     * Builds {@code prefix} followed by one {@code ",policyweightsN=<file>"} entry per
     * evaluated feature-weights file (N starting at 1).
     */
    private String policyWeightsConfig(final String prefix) {
        final StringBuilder sb = new StringBuilder(prefix);
        if (evalFeatureWeightsFilepaths != null) {
            for (int i = 1; i <= evalFeatureWeightsFilepaths.size(); ++i) {
                sb.append(",policyweights").append(i).append('=').append(evalFeatureWeightsFilepaths.get(i - 1));
            }
        }
        return sb.toString();
    }

    /**
     * Builds the AIFactory config string for a Biased MCTS agent whose playouts (and learned
     * selection policy) use softmax policies with the evaluated feature weights.
     */
    private String biasedMctsConfig() {
        return StringRoutines.join(";",
                "algorithm=MCTS",
                "selection=noisyag0selection",
                policyWeightsConfig("playout=softmax"),
                "final_move=robustchild",
                "tree_reuse=true",
                "learned_selection_policy=playout",
                "friendly_name=BiasedMCTS");
    }

    /**
     * Creates a fresh instance of the agent being evaluated.
     *
     * @return the new AI, or {@code null} (after printing an error) if {@link #evalAgent} is not recognised
     */
    private AI createEvalAI() {
        switch (evalAgent) {
            case "Alpha-Beta":
                return AIFactory.createAI("algorithm=Alpha-Beta;heuristics=" + evalHeuristicsFilepath);
            case "BiasedMCTS":
                return AIFactory.createAI(biasedMctsConfig());
            case "BiasedMCTSUniformPlayouts":
                return AIFactory.createAI(StringRoutines.join(";",
                        "algorithm=MCTS",
                        "selection=noisyag0selection",
                        "playout=random",
                        "final_move=robustchild",
                        "tree_reuse=true",
                        policyWeightsConfig("learned_selection_policy=softmax"),
                        "friendly_name=BiasedMCTSUniformPlayouts"));
            default:
                System.err.println("Can't build eval AI: " + evalAgent);
                return null;
        }
    }

    /** Compiles the gate's BestFeatures.txt from the given data directory. */
    private static Features loadBestFeatures(final String dirPath, final Report report) throws IOException {
        return (Features) Compiler.compileObject(
                FileHandling.loadTextContentsFromFile(dirPath + "/BestFeatures.txt"),
                "metadata.ai.features.Features",
                report);
    }

    /**
     * Creates a fresh instance of the gate agent, built from the files in the best-agents data directory.
     *
     * @return the new AI, or {@code null} (after printing an error) if it could not be built
     */
    private AI createGateAI() {
        final String dirPath = bestAgentsDataDirPath();
        final Report report = new Report();
        try {
            if (gateAgentType.equals("BestAgent")) {
                final BestAgent bestAgent = (BestAgent) Compiler.compileObject(
                        FileHandling.loadTextContentsFromFile(dirPath + "/BestAgent.txt"),
                        "metadata.ai.agents.BestAgent",
                        report);
                if (bestAgent == null) {
                    System.err.println("Failed to compile best agent from: " + dirPath + "/BestAgent.txt");
                } else {
                    final String agentName = bestAgent.agent();
                    if ("AlphaBeta".equals(agentName) || "Alpha-Beta".equals(agentName)) {
                        return new AlphaBetaSearch(dirPath + "/BestHeuristics.txt");
                    }
                    if ("AlphaBetaMetadata".equals(agentName)) {
                        return new AlphaBetaSearch();
                    }
                    if ("UCT".equals(agentName)) {
                        return AIFactory.createAI("UCT");
                    }
                    if ("MC-GRAVE".equals(agentName)) {
                        return AIFactory.createAI("MC-GRAVE");
                    }
                    if ("MAST".equals(agentName)) {
                        return AIFactory.createAI("MAST");
                    }
                    if ("ProgressiveHistory".equals(agentName) || "Progressive History".equals(agentName)) {
                        return AIFactory.createAI("Progressive History");
                    }
                    if ("Biased MCTS".equals(agentName)) {
                        return MCTS.createBiasedMCTS(loadBestFeatures(dirPath, report), 1.0);
                    }
                    if ("Biased MCTS (Uniform Playouts)".equals(agentName)) {
                        return MCTS.createBiasedMCTS(loadBestFeatures(dirPath, report), 0.0);
                    }
                    System.err.println("Unrecognised best agent: " + agentName);
                }
            } else if (gateAgentType.equals("Alpha-Beta")) {
                return new AlphaBetaSearch(dirPath + "/BestHeuristics.txt");
            } else if (gateAgentType.equals("BiasedMCTS")) {
                final Features features = loadBestFeatures(dirPath, report);
                if (evalAgent.equals("BiasedMCTS")) {
                    return MCTS.createBiasedMCTS(features, 1.0);
                }
                if (evalAgent.equals("BiasedMCTSUniformPlayouts")) {
                    return MCTS.createBiasedMCTS(features, 0.0);
                }
                System.err.println("Trying to use Biased MCTS gate when evaluating something other than Biased MCTS!");
            }
        } catch (final IOException e) {
            e.printStackTrace();
        }
        System.err.println("Failed to build gate AI: " + gateAgentType);
        return null;
    }

    /**
     * Rounds {@code value} up to the nearest multiple of {@code divisor}.
     * Non-positive values are returned unchanged.
     */
    private static int roundUpToMultiple(final int value, final int divisor) {
        if (value <= 0) {
            return value;
        }
        final int remainder = value % divisor;
        return remainder == 0 ? value : value + (divisor - remainder);
    }

    /**
     * Runs random playouts on {@code game} for roughly {@code seconds} seconds to warm up the JVM.
     * Uses a {@code nanoTime} difference rather than absolute values, as {@code nanoTime} may be negative.
     */
    private static void warmUpJvm(final Game game, final Context context, final int seconds) {
        final long stopAt = System.nanoTime() + (long) (seconds * 1.0e9);
        while (System.nanoTime() - stopAt < 0L) {
            game.start(context);
            game.playout(context, null, 1.0, null, -1, -1, ThreadLocalRandom.current());
        }
    }

    /**
     * Loads the game and creates alternating eval/gate agents, then plays the gating games
     * inside an {@link InterruptableExperiment}. The experiment is presumably run by that
     * class's constructor, since the instance is not otherwise used (confirm against
     * InterruptableExperiment).
     * If the eval agent's average score is strictly higher than the gate agent's, the
     * best-agent data files are updated.
     */
    public void startExperiment() {
        final Game game = (ruleset == null || ruleset.equals(""))
                ? GameLoader.loadGameFromName(gameName, gameOptions)
                : GameLoader.loadGameFromName(gameName, ruleset);
        final int numPlayers = game.players().count();
        if (gameLengthCap >= 0) {
            game.setMaxTurns(Math.min(gameLengthCap, game.getMaxTurnLimit()));
        }
        final Context context = new Context(game, new Trial(game));

        // Agents alternate eval, gate, eval, gate, ...; an odd player count gets one extra (unused) agent.
        final List<AI> ais = new ArrayList<>(numPlayers % 2 == 0 ? numPlayers : numPlayers + 1);
        for (int p = 0; p < numPlayers; p += 2) {
            final AI evalAI = createEvalAI();
            final AI gateAI = createGateAI();
            if (evalAI == null || gateAI == null) {
                System.err.println("Aborting gating experiment: could not construct both eval and gate agents.");
                return;
            }
            evalAI.setFriendlyName(EVAL_AI_NAME);
            gateAI.setFriendlyName(GATE_AI_NAME);
            ais.add(evalAI);
            ais.add(gateAI);
        }

        new InterruptableExperiment(useGUI, maxWallTime) {
            @Override
            public void runExperiment() {
                // Every seat permutation is played equally often.
                final List<TIntArrayList> permutations = ListUtils.generatePermutations(
                        TIntArrayList.wrap(IntStream.range(0, numPlayers).toArray()));
                final int numGamesToPlay = EvalGate.roundUpToMultiple(EvalGate.this.numGames, permutations.size());

                EvalGate.warmUpJvm(game, context, EvalGate.this.warmingUpSecs);
                System.gc();

                final List<String> agentNames = new ArrayList<>(ais.size());
                for (final AI ai : ais) {
                    agentNames.add(ai.friendlyName());
                }
                final ResultsSummary resultsSummary = new ResultsSummary(game, agentNames);

                for (int gameCounter = 0; gameCounter < numGamesToPlay; ++gameCounter) {
                    checkWallTime(0.05);
                    if (this.interrupted) {
                        break;
                    }

                    final TIntArrayList permutation = permutations.get(gameCounter % permutations.size());
                    // Index 0 is a placeholder: players are numbered from 1.
                    final List<AI> currentAIs = new ArrayList<>(numPlayers + 1);
                    currentAIs.add(null);
                    for (int i = 0; i < permutation.size(); ++i) {
                        currentAIs.add(ais.get(permutation.getQuick(i) % ais.size()));
                    }

                    game.start(context);
                    for (int p = 1; p < currentAIs.size(); ++p) {
                        currentAIs.get(p).initAI(game, p);
                    }

                    final Model model = context.model();
                    while (!context.trial().over() && !this.interrupted) {
                        model.startNewStep(context, currentAIs, EvalGate.this.thinkingTime, -1, -1, 0.0);
                    }

                    // Only completed games are recorded (an interrupt may leave one unfinished).
                    if (context.trial().over()) {
                        final double[] utilities = RankUtils.agentUtilities(context);
                        final int numMovesPlayed =
                                context.trial().numMoves() - context.trial().numInitialPlacementMoves();
                        final int[] agentPermutation = new int[permutation.size() + 1];
                        permutation.toArray(agentPermutation, 0, 1, permutation.size());
                        resultsSummary.recordResults(agentPermutation, utilities, numMovesPlayed);
                    }

                    for (int p = 1; p < currentAIs.size(); ++p) {
                        currentAIs.get(p).closeAI();
                    }
                }

                final double evalScore = resultsSummary.avgScoreForAgentName(EVAL_AI_NAME);
                final double gateScore = resultsSummary.avgScoreForAgentName(GATE_AI_NAME);
                System.out.println("----------------------------------");
                System.out.println("Eval Agent = " + EvalGate.this.evalAgent);
                System.out.println("Gate Agent = " + EvalGate.this.gateAgentType);
                System.out.println();
                System.out.println("Eval Agent Score = " + evalScore);
                System.out.println("Gate Agent Score = " + gateScore);
                System.out.println("----------------------------------");

                if (evalScore > gateScore) {
                    EvalGate.this.writeNewBestData();
                }
            }
        };
    }

    /**
     * Called after the eval agent beat the gate agent. Decides which best-agent data files to
     * overwrite, based on the gate and eval agent types, and writes them.
     */
    private void writeNewBestData() {
        boolean writeBestAgent = false;
        boolean writeFeatures = false;
        boolean writeHeuristics = false;

        if (gateAgentType.equals("BestAgent")) {
            writeBestAgent = true;
            if (evalAgent.equals("Alpha-Beta")) {
                writeHeuristics = true;
            } else if (evalAgent.contains("BiasedMCTS")) {
                writeFeatures = true;
            } else {
                System.err.println("Eval agent is neither Alpha-Beta nor a variant of BiasedMCTS");
            }
        } else if (gateAgentType.equals("Alpha-Beta")) {
            if (evalAgent.equals("Alpha-Beta")) {
                writeHeuristics = true;
            } else {
                System.err.println("evalAgent = " + evalAgent + " against gateAgentType = " + gateAgentType);
            }
        } else if (gateAgentType.equals("BiasedMCTS")) {
            if (evalAgent.contains("BiasedMCTS")) {
                writeFeatures = true;
            } else {
                System.err.println("evalAgent = " + evalAgent + " against gateAgentType = " + gateAgentType);
            }
        } else {
            System.err.println("Unrecognised gate agent type: " + gateAgentType);
        }

        final String dirPath = bestAgentsDataDirPath();
        if (writeBestAgent) {
            writeBestAgentFile(dirPath);
        }
        if (writeHeuristics) {
            writeBestHeuristicsFile(dirPath);
        }
        if (writeFeatures) {
            writeBestFeaturesFile(dirPath);
        }
    }

    /** Overwrites BestAgent.txt in {@code dirPath} with the eval agent's type. */
    private void writeBestAgentFile(final String dirPath) {
        // Build the content before opening the writer: PrintWriter(File) truncates the existing file.
        final BestAgent bestAgent;
        switch (evalAgent) {
            case "Alpha-Beta":
                bestAgent = new BestAgent("AlphaBeta");
                break;
            case "BiasedMCTS":
                bestAgent = new BestAgent("Biased MCTS");
                break;
            case "BiasedMCTSUniformPlayouts":
                bestAgent = new BestAgent("Biased MCTS (Uniform Playouts)");
                break;
            default:
                System.err.println("Cannot write best agent for unrecognised eval agent: " + evalAgent);
                return;
        }
        try (PrintWriter writer = new PrintWriter(new File(dirPath + "/BestAgent.txt"))) {
            System.out.println("Writing new best agent: " + evalAgent);
            writer.println(bestAgent.toString());
        } catch (final FileNotFoundException e) {
            e.printStackTrace();
        }
    }

    /** Overwrites BestHeuristics.txt in {@code dirPath} with the evaluated heuristics. */
    private void writeBestHeuristicsFile(final String dirPath) {
        try {
            // Load the new heuristics BEFORE opening the output file: PrintWriter(File) truncates it,
            // so a failed load would otherwise wipe the current best heuristics.
            final Heuristics heuristics = (Heuristics) Compiler.compileObject(
                    FileHandling.loadTextContentsFromFile(evalHeuristicsFilepath),
                    "metadata.ai.heuristics.Heuristics",
                    new Report());
            if (heuristics == null) {
                System.err.println("Failed to compile eval heuristics from: " + evalHeuristicsFilepath);
                return;
            }
            try (PrintWriter writer = new PrintWriter(new File(dirPath + "/BestHeuristics.txt"))) {
                System.out.println("writing new best heuristics");
                writer.println(heuristics.toString());
            }
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }

    /** Overwrites BestFeatures.txt in {@code dirPath} with features metadata built from the eval feature weights. */
    private void writeBestFeaturesFile(final String dirPath) {
        // Always build the softmax-playout variant: both policies are cast to SoftmaxPolicyLinear below.
        final MCTS mcts = (MCTS) AIFactory.createAI(biasedMctsConfig());
        final Features features = AIUtils.generateFeaturesMetadata(
                (SoftmaxPolicyLinear) mcts.learnedSelectionPolicy(),
                (SoftmaxPolicyLinear) mcts.playoutStrategy());
        try (PrintWriter writer = new PrintWriter(new File(dirPath + "/BestFeatures.txt"))) {
            System.out.println("writing new best features");
            writer.println(features.toString());
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }

    /** Unchecked cast of a parsed multi-valued command-line option to a list of strings. */
    @SuppressWarnings("unchecked")
    private static List<String> asStringList(final Object value) {
        return (List<String>) value;
    }

    /**
     * Parses command-line arguments and runs the gating experiment.
     *
     * @param args command-line arguments
     */
    public static void main(final String[] args) {
        JITSPatterNetFeatureSet.ALLOW_FEATURE_SET_CACHE = true;

        final CommandLineArgParse argParse = new CommandLineArgParse(true,
                "Gating experiment to test if a newly-trained agent outperforms current best.");

        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--game")
                .help("Name of the game to play. Should end with \".lud\".")
                .withDefault("Amazons.lud")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--game-options")
                .help("Game Options to load.")
                .withDefault(new ArrayList<String>(0))
                .withNumVals("*")
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--ruleset")
                .help("Ruleset to compile.")
                .withDefault("")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--eval-agent")
                .help("Agent to be evaluated.")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String)
                .withLegalVals("Alpha-Beta", "BiasedMCTS", "BiasedMCTSUniformPlayouts")
                .setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--eval-feature-weights-filepaths")
                .help("Filepaths for feature weights to be evaluated.")
                .withNumVals("*")
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--eval-heuristics-filepath")
                .help("Filepath for heuristics to be evaluated.")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("-n", "--num-games", "--num-eval-games")
                .help("Number of evaluation games to run.")
                .withDefault(200)
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.Int));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--game-length-cap", "--max-num-actions")
                .help("Maximum number of actions that may be taken before a game is terminated as a draw (-1 for no limit).")
                .withDefault(-1)
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.Int));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--thinking-time", "--time", "--seconds")
                .help("Max allowed thinking time per move (in seconds).")
                .withDefault(Double.valueOf(1.0))
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.Double));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--warming-up-secs")
                .help("Number of seconds for which to warm up JVM.")
                .withType(CommandLineArgParse.OptionTypes.Int)
                .withNumVals(1)
                .withDefault(60));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--best-agents-data-dir")
                .help("Filepath for directory containing data on best agents")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String)
                .setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--gate-agent-type")
                .help("Type of gate agent against which we wish to evaluate.")
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.String)
                .setRequired()
                .withLegalVals("BestAgent", "Alpha-Beta", "BiasedMCTS"));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--useGUI")
                .help("Whether to create a small GUI that can be used to manually interrupt the evaluation run. False by default."));
        argParse.addOption(new CommandLineArgParse.ArgOption()
                .withNames("--max-wall-time")
                .help("Max wall time in minutes (or -1 for no limit).")
                .withDefault(-1)
                .withNumVals(1)
                .withType(CommandLineArgParse.OptionTypes.Int));

        if (!argParse.parseArguments(args)) {
            return;
        }

        final EvalGate evalGate = new EvalGate(argParse.getValueBool("--useGUI"), argParse.getValueInt("--max-wall-time"));
        evalGate.gameName = argParse.getValueString("--game");
        evalGate.gameOptions = asStringList(argParse.getValue("--game-options"));
        evalGate.ruleset = argParse.getValueString("--ruleset");
        evalGate.evalAgent = argParse.getValueString("--eval-agent");
        evalGate.evalFeatureWeightsFilepaths = asStringList(argParse.getValue("--eval-feature-weights-filepaths"));
        evalGate.evalHeuristicsFilepath = argParse.getValueString("--eval-heuristics-filepath");
        evalGate.numGames = argParse.getValueInt("-n");
        evalGate.gameLengthCap = argParse.getValueInt("--game-length-cap");
        evalGate.thinkingTime = argParse.getValueDouble("--thinking-time");
        evalGate.warmingUpSecs = argParse.getValueInt("--warming-up-secs");
        evalGate.bestAgentsDataDir = new File(argParse.getValueString("--best-agents-data-dir"));
        evalGate.gateAgentType = argParse.getValueString("--gate-agent-type");
        evalGate.startExperiment();
    }
}
