package supplementary.experiments.scripts;

import com.itextpdf.text.xml.xmp.XmpWriter;
import features.spatial.Walk;
import game.Game;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import main.CommandLineArgParse;
import main.FileHandling;
import main.StringRoutines;
import main.UnixPrintWriter;
import main.collections.ArrayUtils;
import main.options.Ruleset;
import org.apache.batik.constants.XMLConstants;
import org.apache.batik.svggen.SVGSyntax;
import org.apache.batik.util.SVGConstants;
import other.GameLoader;
import supplementary.experiments.analysis.RulesetConceptsUCT;
import utils.RulesetNames;

/**
 * Generates SLURM job scripts for expert-iteration feature training ({@code --expert-iteration})
 * of every supported game ruleset in the Ludii game library, targeting the Snellius cluster.
 *
 * <p>Writes {@code TrainFeatures_<i>.sh} job scripts. Each one requests a single node and runs up
 * to {@link #PROCESSES_PER_JOB} Ludii training processes in the background. Each process is pinned
 * with {@code taskset} to its own {@link #CORES_PER_PROCESS} cores. The generator also writes
 * {@code SubmitJobs_Part<k>.sh} scripts, each containing at most {@link #MAX_JOBS_PER_BATCH}
 * {@code sbatch} commands.
 */
public class ExItTrainingScriptsGenSnelliusAllGames {
    /** Maximum number of {@code sbatch} lines per {@code SubmitJobs_Part*.sh} script. */
    private static final int MAX_JOBS_PER_BATCH = 800;

    /** JVM heap size in MB, used for both {@code -Xms} and {@code -Xmx} of every process. */
    private static final String JVM_MEM = "5120";

    /** Memory (GB) requested per process for jobs at or below the exclusive threshold. */
    private static final int MEM_PER_PROCESS = 6;

    /** Memory (GB) of a full node, requested by jobs above the exclusive threshold. */
    private static final int MEM_PER_NODE = 256;

    /** Upper bound (GB) on the {@code --mem} value of any job script. */
    private static final int MAX_REQUEST_MEM = 224;

    /** Value passed to Ludii's {@code -n} option for every training process. */
    private static final int MAX_SELFPLAY_TRIALS = 200;

    /** Job wall time in minutes (SLURM {@code -t}); also passed to Ludii's {@code --max-wall-time}. */
    private static final int MAX_WALL_TIME = 1445;

    /** Cores per node; every job is charged as a full node in the reported core-hour total. */
    private static final int CORES_PER_NODE = 128;

    /** Cores pinned to each training process; also its number of agent threads. */
    private static final int CORES_PER_PROCESS = 3;

    /** Jobs with more processes than this request a full node's memory. */
    private static final int EXCLUSIVE_PROCESSES_THRESHOLD = 32;

    /** Maximum number of training processes per job script. */
    private static final int PROCESSES_PER_JOB = 42;

    /** Game files (matched as suffixes of the file name) that are never scheduled. */
    private static final String[] SKIP_GAMES = {"Chinese Checkers.lud", "Li'b al-'Aqil.lud", "Li'b al-Ghashim.lud", "Mini Wars.lud", "Pagade Kayi Ata (Sixteen-handed).lud", "Taikyoku Shogi.lud"};

    /** Path fragments (with forward slashes) of game directories that are never scheduled. */
    private static final String[] EXCLUDED_DIRECTORIES = {"/lud/bad/", "/lud/wip/", "/lud/WishlistDLP/", "/lud/test/", "/lud/wishlist/", "/lud/reconstruction/", "/lud/simulation/", "/lud/proprietary/"};

    /** One training process: a game file, one of its rulesets, and the game's player count. */
    public static class ProcessData {
        /** Game file name prefixed with "/", as passed to Ludii's {@code --game} option. */
        public final String gameName;
        /** Ruleset heading, or the empty string for a game without real rulesets. */
        public final String rulesetName;
        /** Number of players reported by the loaded game. */
        public final int numPlayers;

        public ProcessData(String gameName, String rulesetName, int numPlayers) {
            this.gameName = gameName;
            this.rulesetName = rulesetName;
            this.numPlayers = numPlayers;
        }
    }

    private ExItTrainingScriptsGenSnelliusAllGames() {
    }

    /**
     * Writes all job scripts and submission scripts into the {@code --scripts-dir} directory,
     * then prints the total requested core hours.
     *
     * <p>If a job script cannot be created, the error is printed and no further job scripts are
     * generated. Submission scripts are still written for the jobs that were created.
     *
     * @param argParse parsed arguments providing {@code --scripts-dir} and {@code --user-name}
     */
    private static void generateScripts(final CommandLineArgParse argParse) {
        String scriptsDir = toForwardSlashes(argParse.getValueString("--scripts-dir"));
        if (!scriptsDir.endsWith("/")) {
            scriptsDir += "/";
        }
        final String userName = argParse.getValueString("--user-name");

        final List<ProcessData> processes = collectProcessData();

        final List<String> jobScriptNames = new ArrayList<>();
        double totalCoreHoursRequested = 0.0;
        int processIdx = 0;
        while (processIdx < processes.size()) {
            final String jobScriptName = "TrainFeatures_" + jobScriptNames.size() + ".sh";
            final int numProcessesThisJob = Math.min(processes.size() - processIdx, PROCESSES_PER_JOB);

            try (final UnixPrintWriter writer = new UnixPrintWriter(new File(scriptsDir + jobScriptName), "UTF-8")) {
                writeJobScript(writer, processes.subList(processIdx, processIdx + numProcessesThisJob), userName);
            } catch (final FileNotFoundException | UnsupportedEncodingException e) {
                e.printStackTrace();
                // processIdx would not advance, so retrying would hit the same file again forever
                break;
            }

            // Every job is charged as a full node for the full wall time (minutes -> hours)
            totalCoreHoursRequested += (CORES_PER_NODE * MAX_WALL_TIME) / 60.0;
            jobScriptNames.add(jobScriptName);
            processIdx += numProcessesThisJob;
        }

        writeSubmissionScripts(scriptsDir, jobScriptNames);
        System.out.println("Total core hours requested = " + totalCoreHoursRequested);
    }

    /**
     * Builds one {@link ProcessData} per supported (game, ruleset) pair in the game library.
     *
     * <p>The result is sorted in descending order of the ruleset's "DurationMoves" concept value.
     * Rulesets with no known value (NaN) are treated as having the largest value, so they come
     * first.
     *
     * @return the ordered list of training processes to schedule
     */
    private static List<ProcessData> collectProcessData() {
        final String[] gamePaths = Arrays.stream(FileHandling.listGames())
                .filter(path -> !isInExcludedDirectory(path))
                .toArray(String[]::new);

        final List<ProcessData> candidates = new ArrayList<>();
        final TDoubleArrayList durations = new TDoubleArrayList();

        for (final String gamePath : gamePaths) {
            final String[] pathParts = toForwardSlashes(gamePath).split(Pattern.quote("/"));
            final String fileName = pathParts[pathParts.length - 1];
            if (isSkippedGame(fileName)) {
                continue;
            }

            final Game defaultGame = GameLoader.loadGameFromName(gamePath);
            final List<Ruleset> rulesets = new ArrayList<>(defaultGame.description().rulesets());
            // null stands for "the game without an explicit ruleset"; it is considered last
            rulesets.add(null);
            boolean foundRealRuleset = false;

            for (final Ruleset ruleset : rulesets) {
                final String rulesetName;
                final Game game;
                if (ruleset == null) {
                    if (foundRealRuleset) {
                        // Only fall back to the bare game if no real ruleset was found
                        continue;
                    }
                    rulesetName = "";
                    game = defaultGame;
                } else if (ruleset.optionSettings().isEmpty()) {
                    // Rulesets without option settings are skipped entirely
                    continue;
                } else {
                    rulesetName = ruleset.heading();
                    foundRealRuleset = true;
                    game = GameLoader.loadGameFromName(gamePath, rulesetName);
                }

                if (!isSupportedGame(game)) {
                    continue;
                }

                double duration = RulesetConceptsUCT.getValue(RulesetNames.gameRulesetName(game), "DurationMoves");
                if (Double.isNaN(duration)) {
                    duration = Double.MAX_VALUE;
                }
                candidates.add(new ProcessData("/" + fileName, rulesetName, game.players().count()));
                durations.add(duration);
            }
        }

        final Comparator<Integer> largestDurationFirst =
                (a, b) -> Double.compare(durations.getQuick(b), durations.getQuick(a));
        final List<Integer> order = ArrayUtils.sortedIndices(candidates.size(), largestDurationFirst);

        final List<ProcessData> sorted = new ArrayList<>(candidates.size());
        for (final int idx : order) {
            sorted.add(candidates.get(idx));
        }
        return sorted;
    }

    /** Converts all backslashes in a path to forward slashes. */
    private static String toForwardSlashes(final String path) {
        return path.replace('\\', '/');
    }

    /** Returns true if the game path lies in one of {@link #EXCLUDED_DIRECTORIES}. */
    private static boolean isInExcludedDirectory(final String gamePath) {
        final String normalised = toForwardSlashes(gamePath);
        for (final String dir : EXCLUDED_DIRECTORIES) {
            if (normalised.contains(dir)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if the game file name ends with one of {@link #SKIP_GAMES}. */
    private static boolean isSkippedGame(final String fileName) {
        for (final String skip : SKIP_GAMES) {
            if (fileName.endsWith(skip)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if the game can be trained by this pipeline.
     *
     * <p>A game qualifies if it is an alternating-move game with at least one player and at least
     * one rotation. It must have no subgames, stacking, hidden information or simultaneous moves.
     * It must also not be a deduction puzzle, a simulation-move game or a boardless game.
     */
    private static boolean isSupportedGame(final Game game) {
        return !game.hasSubgames()
                && !game.isDeductionPuzzle()
                && !game.isSimulationMoveGame()
                && game.isAlternatingMoveGame()
                && !game.isStacking()
                && !game.isBoardless()
                && !game.hiddenInformation()
                && Walk.allGameRotations(game).length != 0
                && game.players().count() != 0
                && !game.isSimultaneousMoveGame();
    }

    /**
     * Writes the SLURM header, the module loads, one background command per process, and a
     * final {@code wait}.
     *
     * @param writer       destination job script
     * @param jobProcesses processes to run in this job (at most {@link #PROCESSES_PER_JOB})
     * @param userName     cluster user name, used to build home-directory paths
     */
    private static void writeJobScript(final UnixPrintWriter writer, final List<ProcessData> jobProcesses, final String userName) {
        final int numProcesses = jobProcesses.size();
        final boolean requestFullNodeMemory = numProcesses > EXCLUSIVE_PROCESSES_THRESHOLD;
        final int jobMemRequestGB = requestFullNodeMemory
                ? Math.min(MEM_PER_NODE, MAX_REQUEST_MEM)
                : Math.min(numProcesses * MEM_PER_PROCESS, MAX_REQUEST_MEM);

        writer.println("#!/bin/bash");
        writer.println("#SBATCH -J TrainFeatures");
        writer.println("#SBATCH -p thin");
        writer.println("#SBATCH -o /home/" + userName + "/TrainFeaturesSnelliusAllGames/Out/Out_%J.out");
        writer.println("#SBATCH -e /home/" + userName + "/TrainFeaturesSnelliusAllGames/Out/Err_%J.err");
        writer.println("#SBATCH -t " + MAX_WALL_TIME);
        writer.println("#SBATCH -N 1");
        writer.println("#SBATCH --cpus-per-task=" + (numProcesses * CORES_PER_PROCESS));
        writer.println("#SBATCH --mem=" + jobMemRequestGB + "G");
        // Always exclusive, regardless of job size; presumably because every process is pinned
        // to fixed core IDs with taskset (TODO confirm the original intent)
        writer.println("#SBATCH --exclusive");
        writer.println("module load 2021");
        writer.println("module load Java/11.0.2");
        for (int slot = 0; slot < numProcesses; ++slot) {
            writer.println(trainingCommand(jobProcesses.get(slot), slot, userName));
        }
        writer.println("wait");
    }

    /**
     * Builds the shell line that launches one background training process.
     *
     * <p>The process is pinned to cores {@code [slot * CORES_PER_PROCESS, (slot + 1) * CORES_PER_PROCESS)},
     * and its stdout and stderr are redirected to per-slot files.
     *
     * @param processData game and ruleset to train
     * @param slot        index of the process within its job
     * @param userName    cluster user name, used to build home-directory paths
     * @return the complete command line, ending in {@code &}
     */
    private static String trainingCommand(final ProcessData processData, final int slot, final String userName) {
        final String baseDir = "/home/" + userName + "/TrainFeaturesSnelliusAllGames";

        final String[] cores = new String[CORES_PER_PROCESS];
        for (int c = 0; c < CORES_PER_PROCESS; ++c) {
            cores[c] = String.valueOf(slot * CORES_PER_PROCESS + c);
        }

        final String outDir = baseDir + "/Out"
                + StringRoutines.cleanGameName(processData.gameName.replaceAll(Pattern.quote(".lud"), ""))
                + "_"
                + StringRoutines.cleanRulesetName(processData.rulesetName).replaceAll(Pattern.quote("/"), "_")
                + "/";

        final String javaCommand = StringRoutines.join(
                " ",
                "taskset",
                "-c",
                StringRoutines.join(",", cores),
                "java",
                "-Xms" + JVM_MEM + "M",
                "-Xmx" + JVM_MEM + "M",
                "-XX:+HeapDumpOnOutOfMemoryError",
                "-da",
                "-dsa",
                "-XX:+UseStringDeduplication",
                "-jar",
                StringRoutines.quote(baseDir + "/Ludii.jar"),
                "--expert-iteration",
                "--game",
                StringRoutines.quote(processData.gameName),
                "--ruleset",
                StringRoutines.quote(processData.rulesetName),
                "-n",
                String.valueOf(MAX_SELFPLAY_TRIALS),
                "--game-length-cap 1000",
                "--thinking-time 1",
                "--iteration-limit 12000",
                "--wis",
                "--playout-features-epsilon 0.5",
                "--no-value-learning",
                "--train-tspg",
                "--checkpoint-freq 20",
                "--num-agent-threads",
                String.valueOf(CORES_PER_PROCESS),
                "--num-feature-discovery-threads",
                String.valueOf(Math.min(processData.numPlayers, CORES_PER_PROCESS)),
                "--out-dir",
                StringRoutines.quote(outDir),
                "--no-logging",
                "--max-wall-time",
                String.valueOf(MAX_WALL_TIME));

        final String redirects = StringRoutines.join(
                " ",
                ">",
                baseDir + "/Out/Out_${SLURM_JOB_ID}_" + slot + ".out",
                "2>",
                baseDir + "/Out/Err_${SLURM_JOB_ID}_" + slot + ".err",
                "&");

        return javaCommand
                + " --special-moves-expander-split"
                + " --handle-aliasing"
                + " --is-episode-durations"
                + " --prioritized-experience-replay"
                + " " + redirects;
    }

    /**
     * Writes {@code SubmitJobs_Part<k>.sh} scripts, each containing {@code sbatch} lines for at
     * most {@link #MAX_JOBS_PER_BATCH} job scripts.
     *
     * <p>A part that cannot be written is reported and skipped; the remaining parts are still
     * attempted.
     *
     * @param scriptsDir     output directory, ending in "/"
     * @param jobScriptNames job script file names, in submission order
     */
    private static void writeSubmissionScripts(final String scriptsDir, final List<String> jobScriptNames) {
        for (int start = 0; start < jobScriptNames.size(); start += MAX_JOBS_PER_BATCH) {
            final int part = start / MAX_JOBS_PER_BATCH;
            final int end = Math.min(start + MAX_JOBS_PER_BATCH, jobScriptNames.size());
            try (final UnixPrintWriter writer = new UnixPrintWriter(new File(scriptsDir + "SubmitJobs_Part" + part + ".sh"), "UTF-8")) {
                for (final String jobScriptName : jobScriptNames.subList(start, end)) {
                    writer.println("sbatch " + jobScriptName);
                }
            } catch (final FileNotFoundException | UnsupportedEncodingException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Entry point: parses {@code --user-name} and {@code --scripts-dir}, then generates all
     * scripts.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        final CommandLineArgParse argParse = new CommandLineArgParse(true, "Creating feature training job scripts for Snellius cluster.");
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--user-name").help("Username on the cluster.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        argParse.addOption(new CommandLineArgParse.ArgOption().withNames("--scripts-dir").help("Directory in which to store generated scripts.").withNumVals(1).withType(CommandLineArgParse.OptionTypes.String).setRequired());
        if (argParse.parseArguments(args)) {
            generateScripts(argParse);
        }
    }
}
