package supplementary.experiments.feature_sets;

import com.itextpdf.text.xml.xmp.XmpWriter;
import features.feature_sets.BaseFeatureSet;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import features.feature_sets.network.SPatterNet;
import features.feature_sets.network.SPatterNetFeatureSet;
import function_approx.LinearFunction;
import game.Game;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import main.CommandLineArgParse;
import main.StringRoutines;
import main.UnixPrintWriter;
import org.apache.batik.svggen.SVGSyntax;
import other.AI;
import other.GameLoader;
import other.context.Context;
import other.trial.Trial;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import search.mcts.backpropagation.MonteCarloBackprop;
import search.mcts.finalmoveselection.RobustChild;
import search.mcts.selection.AG0Selection;
import utils.ExperimentFileUtils;

/* loaded from: input_file:supplementary/experiments/feature_sets/MemoryUsage.class */
/**
 * Experiment comparing the SPatterNet keys and propositions held by the
 * {@link SPatterNetFeatureSet} and {@link JITSPatterNetFeatureSet} implementations when
 * loaded from the trained policies of a fixed list of games. Counts and their ratios
 * (SPatterNet / JIT) are written as one CSV row per game.
 */
public class MemoryUsage {
    /** Game files (names as passed to {@link GameLoader#loadGameFromName}) to evaluate. */
    private static final String[] GAMES = {"Alquerque.lud", "Amazons.lud", "ArdRi.lud", "Arimaa.lud", "Ataxx.lud", "Bao Ki Arabu (Zanzibar 1).lud", "Bizingo.lud", "Breakthrough.lud", "Chess.lud", "Chinese Checkers.lud", "English Draughts.lud", "Fanorona.lud", "Fox and Geese.lud", "Go.lud", "Gomoku.lud", "Gonnect.lud", "Havannah.lud", "Hex.lud", "Kensington.lud", "Knightthrough.lud", "Konane.lud", "Level Chess.lud", "Lines of Action.lud", "Pentalath.lud", "Pretwa.lud", "Reversi.lud", "Royal Game of Ur.lud", "Surakarta.lud", "Shobu.lud", "Tablut.lud", "Triad.lud", "XII Scripta.lud", "Yavalath.lud"};

    /**
     * Number of game steps simulated (restarting the game whenever a trial ends) for each
     * feature set implementation before its nets are counted.
     */
    private static final int NUM_STEPS = 60;

    /** Keys and summed proposition counts accumulated for one feature set implementation. */
    private static final class NetStats {
        /** Distinct reactive keys encountered, mapped to the net last stored for them. */
        final Map<BaseFeatureSet.ReactiveFeaturesKey, SPatterNet> reactive = new HashMap<>();
        /** Distinct proactive keys encountered, mapped to the net last stored for them. */
        final Map<BaseFeatureSet.ProactiveFeaturesKey, SPatterNet> proactive = new HashMap<>();
        /** Sum of {@code numPropositions()} over the counted reactive nets. */
        long reactiveProps = 0L;
        /** Sum of {@code numPropositions()} over the counted proactive nets. */
        long proactiveProps = 0L;
    }

    /**
     * Runs the experiment for every game in {@link #GAMES} and writes a CSV with one row per
     * successfully evaluated game.
     *
     * <p>Reads {@code --training-out-dir} (containing one sub-directory of policy weights per
     * cleaned game name) and {@code --out-file} (the CSV destination). Games for which the
     * policy weights of some player cannot be resolved are reported on stderr and skipped.
     *
     * @param commandLineArgParse parsed command-line arguments
     */
    private static void evalMemoryUsage(CommandLineArgParse commandLineArgParse) {
        String trainingOutDir = commandLineArgParse.getValueString("--training-out-dir");
        if (!trainingOutDir.endsWith("/")) {
            trainingOutDir = trainingOutDir + "/";
        }
        final File outFile = new File(commandLineArgParse.getValueString("--out-file"));

        // try-with-resources: the writer is closed (flushing rows already written) even if
        // evaluating some game throws, so results for earlier games are not lost.
        try (UnixPrintWriter writer = new UnixPrintWriter(outFile, "UTF-8")) {
            writer.println(StringRoutines.join(",", "game", "spatternet_num_keys_proactive", "spatternet_num_keys_reactive", "spatternet_num_props_proactive", "spatternet_num_props_reactive", "jit_num_keys_proactive", "jit_num_keys_reactive", "jit_num_props_proactive", "jit_num_props_reactive", "keys_ratio", "keys_ratio_proactive", "keys_ratio_reactive", "props_ratio", "props_ratio_proactive", "props_ratio_reactive"));
            for (final String gameName : GAMES) {
                final List<String> row = evalGame(gameName, trainingOutDir);
                if (row != null) {
                    writer.println(StringRoutines.join(",", row));
                }
            }
        } catch (final FileNotFoundException | UnsupportedEncodingException e) {
            e.printStackTrace();
        }
    }

    /**
     * Evaluates a single game with both feature set implementations.
     *
     * @param gameName       game file name, e.g. {@code "Chess.lud"}
     * @param trainingOutDir training output directory, ending in {@code '/'}
     * @return the CSV row values for this game, or {@code null} if the game was skipped
     *         because some player's policy weights could not be resolved
     */
    private static List<String> evalGame(final String gameName, final String trainingOutDir) {
        System.out.println("Game: " + gameName);
        final Game game = GameLoader.loadGameFromName(gameName);
        final int numPlayers = game.players().count();
        // String.replace is a literal (non-regex) replacement of every occurrence.
        final String cleanGameName = StringRoutines.cleanGameName(gameName.replace(".lud", ""));

        // Index 0 is unused; players are numbered 1..numPlayers.
        final String[] policyFilepaths = new String[numPlayers + 1];
        boolean allResolved = true;
        for (int p = 1; p <= numPlayers; ++p) {
            final String defaultFilepath = trainingOutDir + cleanGameName + "/PolicyWeightsCE_P" + p + "_00201.txt";
            policyFilepaths[p] = resolvePolicyFilepath(defaultFilepath, p);
            if (policyFilepaths[p] == null) {
                System.err.println("Cannot resolve policy weights filepath: " + defaultFilepath);
                allResolved = false;
            }
        }
        if (!allResolved) {
            // Continuing would pass null paths on to LinearFunction.fromFile / new File.
            System.err.println("Skipping game " + gameName + ": policy weights could not be resolved for every player.");
            return null;
        }

        final LinearFunction[] linearFunctions = new LinearFunction[numPlayers + 1];
        for (int p = 1; p <= numPlayers; ++p) {
            linearFunctions[p] = LinearFunction.fromFile(policyFilepaths[p]);
        }

        final NetStats spatterNetStats = new NetStats();
        final NetStats jitStats = new NetStats();
        final BaseFeatureSet[] featureSets = new BaseFeatureSet[numPlayers + 1];
        final BaseFeatureSet.FeatureSetImplementations[] implementations = {
            BaseFeatureSet.FeatureSetImplementations.SPATTERNET,
            BaseFeatureSet.FeatureSetImplementations.JITSPATTERNET
        };

        for (final BaseFeatureSet.FeatureSetImplementations implementation : implementations) {
            System.out.println("Implementation: " + implementation);
            for (int p = 1; p <= numPlayers; ++p) {
                // The feature set file path stored in the weights file is relative to the
                // weights file's directory.
                final String featureSetFilepath = new File(policyFilepaths[p]).getParent() + File.separator + linearFunctions[p].featureSetFile();
                if (implementation == BaseFeatureSet.FeatureSetImplementations.SPATTERNET) {
                    featureSets[p] = new SPatterNetFeatureSet(featureSetFilepath);
                } else if (implementation == BaseFeatureSet.FeatureSetImplementations.JITSPATTERNET) {
                    featureSets[p] = JITSPatterNetFeatureSet.construct(featureSetFilepath);
                }
            }

            playSteps(game, numPlayers, new SoftmaxPolicyLinear(linearFunctions, featureSets));

            if (implementation == BaseFeatureSet.FeatureSetImplementations.SPATTERNET) {
                collectSPatterNetStats(featureSets, numPlayers, spatterNetStats);
            } else if (implementation == BaseFeatureSet.FeatureSetImplementations.JITSPATTERNET) {
                collectJitStats(featureSets, numPlayers, jitStats);
            }
            System.out.println();
        }
        System.out.println();

        return buildRow(gameName, spatterNetStats, jitStats);
    }

    /**
     * Resolves the policy weights file for one player. If {@code defaultFilepath} does not
     * exist, falls back to the last matching weights file in the same directory. The file
     * prefix is chosen by which marker substring the default path contains.
     *
     * @param defaultFilepath preferred weights file path
     * @param player          player index (1-based)
     * @return the resolved path, or {@code null} if no fallback prefix applies or
     *         {@link ExperimentFileUtils#getLastFilepath} finds nothing
     */
    private static String resolvePolicyFilepath(final String defaultFilepath, final int player) {
        if (new File(defaultFilepath).exists()) {
            return defaultFilepath;
        }
        final String parent = new File(defaultFilepath).getParent();
        final String prefix;
        if (defaultFilepath.contains("Selection")) {
            prefix = "/PolicyWeightsSelection_P";
        } else if (defaultFilepath.contains("Playout")) {
            prefix = "/PolicyWeightsPlayout_P";
        } else if (defaultFilepath.contains("TSPG")) {
            prefix = "/PolicyWeightsTSPG_P";
        } else if (defaultFilepath.contains("PolicyWeightsCE")) {
            prefix = "/PolicyWeightsCE_P";
        } else {
            return null;
        }
        return ExperimentFileUtils.getLastFilepath(parent + prefix + player, "txt");
    }

    /**
     * Plays {@link #NUM_STEPS} steps of {@code game} with one MCTS agent per player, all
     * sharing {@code policy}. A new game is started (and the agents re-initialised) whenever
     * the current trial is over. Presumably this makes the feature sets build the nets they
     * need during play, notably for the JIT implementation; confirm if relied upon.
     *
     * @param game       game to play
     * @param numPlayers number of players in {@code game}
     * @param policy     policy used both as playout strategy and learned selection policy
     */
    private static void playSteps(final Game game, final int numPlayers, final SoftmaxPolicyLinear policy) {
        final List<AI> ais = new ArrayList<>();
        ais.add(null); // players are 1-indexed
        for (int p = 1; p <= numPlayers; ++p) {
            final MCTS mcts = new MCTS(new AG0Selection(), policy, new MonteCarloBackprop(), new RobustChild());
            mcts.setLearnedSelectionPolicy(policy);
            mcts.setQInit(MCTS.QInit.WIN);
            ais.add(mcts);
        }

        final Trial trial = new Trial(game);
        final Context context = new Context(game, trial);
        boolean needsRestart = true;
        for (int step = 0; step < NUM_STEPS; ++step) {
            if (needsRestart) {
                needsRestart = false;
                game.start(context);
                final long startTime = System.currentTimeMillis();
                for (int p = 1; p <= numPlayers; ++p) {
                    ais.get(p).initAI(game, p);
                }
                System.out.println("init for " + numPlayers + " players took " + ((System.currentTimeMillis() - startTime) / 1000.0d) + " seconds.");
            }
            context.model().startNewStep(context, ais, 1.0d);
            if (trial.over()) {
                needsRestart = true;
            }
        }
    }

    /**
     * Adds the thresholded reactive and proactive nets of every player's
     * {@link SPatterNetFeatureSet} to {@code stats}. Every entry's propositions are summed.
     * Unlike {@link #addNewJitNets}, there is no check whether the key was already present
     * for an earlier player.
     */
    private static void collectSPatterNetStats(final BaseFeatureSet[] featureSets, final int numPlayers, final NetStats stats) {
        for (int p = 1; p <= numPlayers; ++p) {
            System.out.println("p = " + p);
            final SPatterNetFeatureSet featureSet = (SPatterNetFeatureSet) featureSets[p];
            final Map<BaseFeatureSet.ReactiveFeaturesKey, SPatterNet> reactiveNets = featureSet.reactiveFeaturesThresholded();
            final Map<BaseFeatureSet.ProactiveFeaturesKey, SPatterNet> proactiveNets = featureSet.proactiveFeaturesThresholded();
            for (final Map.Entry<BaseFeatureSet.ReactiveFeaturesKey, SPatterNet> entry : reactiveNets.entrySet()) {
                stats.reactiveProps += entry.getValue().numPropositions();
                stats.reactive.put(entry.getKey(), entry.getValue());
            }
            for (final Map.Entry<BaseFeatureSet.ProactiveFeaturesKey, SPatterNet> entry : proactiveNets.entrySet()) {
                stats.proactiveProps += entry.getValue().numPropositions();
                stats.proactive.put(entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * Adds the nets of every player's {@link JITSPatterNetFeatureSet} to {@code stats}. The
     * thresholded map is added first, then the full map; keys already present are not
     * counted twice.
     */
    private static void collectJitStats(final BaseFeatureSet[] featureSets, final int numPlayers, final NetStats stats) {
        for (int p = 1; p <= numPlayers; ++p) {
            System.out.println("p = " + p);
            final JITSPatterNetFeatureSet featureSet = (JITSPatterNetFeatureSet) featureSets[p];
            addNewJitNets(featureSet.spatterNetMapThresholded(), stats);
            addNewJitNets(featureSet.spatterNetMap(), stats);
        }
    }

    /**
     * Stores every net of {@code nets} in the reactive or proactive map of {@code stats}
     * (by key type). Propositions are summed only for keys not previously present.
     */
    private static void addNewJitNets(final Map<? extends BaseFeatureSet.MoveFeaturesKey, ? extends SPatterNet> nets, final NetStats stats) {
        for (final Map.Entry<? extends BaseFeatureSet.MoveFeaturesKey, ? extends SPatterNet> entry : nets.entrySet()) {
            final BaseFeatureSet.MoveFeaturesKey key = entry.getKey();
            final SPatterNet net = entry.getValue();
            if (key instanceof BaseFeatureSet.ReactiveFeaturesKey) {
                if (stats.reactive.put((BaseFeatureSet.ReactiveFeaturesKey) key, net) == null) {
                    stats.reactiveProps += net.numPropositions();
                }
            } else if (stats.proactive.put((BaseFeatureSet.ProactiveFeaturesKey) key, net) == null) {
                stats.proactiveProps += net.numPropositions();
            }
        }
    }

    /**
     * Builds the CSV row for one game, in the column order of the header written by
     * {@link #evalMemoryUsage}.
     */
    private static List<String> buildRow(final String gameName, final NetStats sp, final NetStats jit) {
        final List<String> row = new ArrayList<>();
        row.add(StringRoutines.quote(gameName));
        row.add(String.valueOf(sp.proactive.size()));
        row.add(String.valueOf(sp.reactive.size()));
        row.add(String.valueOf(sp.proactiveProps));
        row.add(String.valueOf(sp.reactiveProps));
        row.add(String.valueOf(jit.proactive.size()));
        row.add(String.valueOf(jit.reactive.size()));
        row.add(String.valueOf(jit.proactiveProps));
        row.add(String.valueOf(jit.reactiveProps));
        row.add(String.valueOf(ratio(sp.proactive.size() + sp.reactive.size(), jit.proactive.size() + jit.reactive.size())));
        row.add(String.valueOf(ratio(sp.proactive.size(), jit.proactive.size())));
        row.add(String.valueOf(ratio(sp.reactive.size(), jit.reactive.size())));
        row.add(String.valueOf(ratio(sp.proactiveProps + sp.reactiveProps, jit.proactiveProps + jit.reactiveProps)));
        row.add(String.valueOf(ratio(sp.proactiveProps, jit.proactiveProps)));
        row.add(String.valueOf(ratio(sp.reactiveProps, jit.reactiveProps)));
        return row;
    }

    /**
     * Floating-point ratio {@code numerator / denominator}. Unlike integer division this
     * does not truncate, and a zero denominator yields {@code NaN} or {@code Infinity}
     * instead of throwing {@link ArithmeticException}.
     */
    private static double ratio(final long numerator, final long denominator) {
        return (double) numerator / denominator;
    }

    /**
     * Entry point. Requires {@code --training-out-dir} and {@code --out-file}.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        CommandLineArgParse commandLineArgParse = new CommandLineArgParse(true, "Eval memory usage of feature sets.");
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--training-out-dir").help("Output directory for training results.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        commandLineArgParse.addOption(new CommandLineArgParse.ArgOption().withNames("--out-file").help("Filepath to write our output CSV to.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        if (commandLineArgParse.parseArguments(strArr)) {
            evalMemoryUsage(commandLineArgParse);
        }
    }
}
