package search.mcts;

import expert_iteration.ExItExperience;
import expert_iteration.ExpertPolicy;
import game.Game;
import java.util.List;
import java.util.Locale;
import main.collections.FVector;
import main.collections.FastArrayList;
import metadata.ai.features.Features;
import org.apache.batik.svggen.SVGSyntax;
import org.apache.batik.util.SVGConstants;
import org.json.JSONObject;
import policies.softmax.SoftmaxFromMetadata;
import policies.softmax.SoftmaxPolicy;
import search.mcts.backpropagation.Backpropagation;
import search.mcts.finalmoveselection.ActRegPolOpt;
import search.mcts.finalmoveselection.FinalMoveSelectionStrategy;
import search.mcts.finalmoveselection.MaxAvgScore;
import search.mcts.finalmoveselection.ProportionalExpVisitCount;
import search.mcts.finalmoveselection.RobustChild;
import search.mcts.nodes.BaseNode;
import search.mcts.nodes.Node;
import search.mcts.nodes.OpenLoopNode;
import search.mcts.playout.PlayoutStrategy;
import search.mcts.playout.RandomPlayout;
import search.mcts.selection.AG0Selection;
import search.mcts.selection.SearchRegPolOpt;
import search.mcts.selection.SelectionStrategy;
import search.mcts.selection.UCB1;
import util.AI;
import util.Context;
import util.Move;
import utils.AIUtils;

/* loaded from: input_file:search/mcts/MCTS.class */
/**
 * Monte-Carlo tree search agent assembled from pluggable parts: a
 * {@link SelectionStrategy} for descending the tree, a {@link PlayoutStrategy}
 * for completing a trial from a leaf, a {@link Backpropagation} step, and a
 * {@link FinalMoveSelectionStrategy} for choosing the move to return.
 *
 * <p>The search tree can optionally be kept between calls to
 * {@link #selectAction} (see {@link #setTreeReuse(boolean)}).
 */
public class MCTS extends ExpertPolicy {
    /**
     * Constructor argument used for every default {@link RandomPlayout}.
     * NOTE(review): presumably a cap on playout length - confirm against RandomPlayout.
     */
    private static final int DEFAULT_RANDOM_PLAYOUT_ARG = 200;

    /**
     * Game-flag bit that switches node creation to {@link OpenLoopNode}.
     * NOTE(review): presumably the "stochastic game" bit - confirm against the game-flag definitions.
     */
    private static final long OPEN_LOOP_FLAG_MASK = 64L;

    /**
     * Game-flag bit that makes {@link #supportsGame(Game)} reject a game.
     * NOTE(review): presumably the "simultaneous moves" bit - confirm against the game-flag definitions.
     */
    private static final long UNSUPPORTED_FLAG_MASK = 1024L;

    /** Strategy used to pick which child to descend into during the selection phase. */
    protected SelectionStrategy selectionStrategy;

    /** Strategy used to play out a trial from the node reached by selection/expansion. */
    protected PlayoutStrategy playoutStrategy;

    /** Propagates the utilities of a finished iteration back through the tree. */
    protected Backpropagation backpropagation;

    /** Strategy used to pick the move returned by {@link #selectAction}. */
    protected FinalMoveSelectionStrategy finalMoveSelectionStrategy;

    /** Flags reported by the selection strategy; they configure {@link #backpropagation}. */
    protected int backpropFlags;

    /** Root of the current search tree, or {@code null} when no tree is held. */
    protected BaseNode rootNode = null;

    /** Initialisation mode exposed through {@link #qInit()}. */
    protected QInit qInit = QInit.PARENT;

    /** Time budget (seconds) applied instead of the caller's budget when the root has a single legal move. */
    protected double autoPlaySeconds = 0.0d;

    /** Game flags captured in {@link #initAI(Game, int)}. */
    protected long currentGameFlags = 0;

    /** Iterations completed by the last search; -1 before any search. */
    protected int lastNumMctsIterations = -1;

    /** Playout actions accumulated by the last search; -1 before any search. */
    protected int lastNumPlayoutActions = -1;

    /** Average score of the child behind the most recently returned move. */
    protected double lastReturnedMoveValueEst = 0.0d;

    /** Human-readable summary of the last search, or {@code null}. */
    protected String analysisReport = null;

    /** When true, the root is left untouched at the end of {@link #selectAction}. */
    protected boolean preserveRootNode = false;

    /** When true, the subtree below the played move is kept for the next search. */
    protected boolean treeReuse = true;

    /** Number of moves in the trial history that the current root already accounts for. */
    protected int lastActionHistorySize = 0;

    /** Optional learned policy exposed through {@link #learnedSelectionPolicy()}; may be {@code null}. */
    protected SoftmaxPolicy learnedSelectionPolicy = null;

    /**
     * Initialisation modes returned by {@link #qInit()}.
     * NOTE(review): presumably these select the value estimate assigned to
     * unvisited children - confirm against the node implementations.
     */
    public enum QInit {
        INF,
        LOSS,
        DRAW,
        WIN,
        PARENT
    }

    /**
     * Creates a UCT agent with exploration constant {@code sqrt(2)}.
     *
     * @return a new agent named "UCT"
     */
    public static MCTS createUCT() {
        return createUCT(Math.sqrt(2.0d));
    }

    /**
     * Creates a UCT agent: UCB1 selection, random playouts and robust-child
     * final move selection.
     *
     * @param d value handed to the {@link UCB1} constructor
     * @return a new agent named "UCT"
     */
    public static MCTS createUCT(double d) {
        MCTS uct = new MCTS(
                new UCB1(d),
                new RandomPlayout(DEFAULT_RANDOM_PLAYOUT_ARG),
                new RobustChild());
        uct.friendlyName = "UCT";
        return uct;
    }

    /**
     * Creates an agent with {@link AG0Selection} whose learned policy is a
     * {@link SoftmaxFromMetadata}.
     *
     * @param z if true the same policy also runs the playouts, otherwise a
     *          {@link RandomPlayout} is used
     * @return the configured agent
     */
    public static MCTS createBiasedMCTS(boolean z) {
        return createPolicyGuided(
                new AG0Selection(),
                new SoftmaxFromMetadata(),
                z,
                new RobustChild(),
                "Biased MCTS",
                "Biased MCTS (Uniform Playouts)");
    }

    /**
     * Creates an agent with {@link AG0Selection} whose learned policy is a
     * {@link SoftmaxPolicy} built from the given features.
     *
     * @param features2 features handed to the {@link SoftmaxPolicy} constructor
     * @param z         if true the same policy also runs the playouts, otherwise a
     *                  {@link RandomPlayout} is used
     * @return the configured agent
     */
    public static MCTS createBiasedMCTS(Features features2, boolean z) {
        return createPolicyGuided(
                new AG0Selection(),
                new SoftmaxPolicy(features2),
                z,
                new RobustChild(),
                "Biased MCTS",
                "Biased MCTS (Uniform Playouts)");
    }

    /**
     * Creates an agent using {@link SearchRegPolOpt} selection and
     * {@link ActRegPolOpt} final move selection, guided by a {@link SoftmaxPolicy}
     * built from the given features.
     *
     * @param features2 features handed to the {@link SoftmaxPolicy} constructor
     * @param z         if true the same policy also runs the playouts, otherwise a
     *                  {@link RandomPlayout} is used
     * @return the configured agent
     */
    public static MCTS createRegPolOptMCTS(Features features2, boolean z) {
        return createPolicyGuided(
                new SearchRegPolOpt(),
                new SoftmaxPolicy(features2),
                z,
                new ActRegPolOpt(),
                "Biased MCTS (RegPolOpt)",
                "Biased MCTS (RegPolOpt, Uniform Playouts)");
    }

    /** Shared body of the policy-guided factories: wires the policy in and picks the friendly name. */
    private static MCTS createPolicyGuided(
            SelectionStrategy selection,
            SoftmaxPolicy policy,
            boolean policyPlayouts,
            FinalMoveSelectionStrategy finalMove,
            String nameWithPolicyPlayouts,
            String nameWithUniformPlayouts) {
        PlayoutStrategy playout;
        if (policyPlayouts) {
            playout = policy;
        } else {
            playout = new RandomPlayout(DEFAULT_RANDOM_PLAYOUT_ARG);
        }
        MCTS agent = new MCTS(selection, playout, finalMove);
        agent.setLearnedSelectionPolicy(policy);
        agent.friendlyName = policyPlayouts ? nameWithPolicyPlayouts : nameWithUniformPlayouts;
        return agent;
    }

    /**
     * Builds an agent from its three strategies. The backpropagation step is
     * configured from {@code selectionStrategy.backpropFlags()}.
     *
     * @param selectionStrategy          strategy for the selection phase
     * @param playoutStrategy            strategy for the playout phase
     * @param finalMoveSelectionStrategy strategy for choosing the returned move
     */
    public MCTS(SelectionStrategy selectionStrategy, PlayoutStrategy playoutStrategy, FinalMoveSelectionStrategy finalMoveSelectionStrategy) {
        this.selectionStrategy = selectionStrategy;
        this.playoutStrategy = playoutStrategy;
        this.backpropFlags = selectionStrategy.backpropFlags();
        this.backpropagation = new Backpropagation(this.backpropFlags);
        this.finalMoveSelectionStrategy = finalMoveSelectionStrategy;
    }

    /**
     * Runs the search from the given context and returns the chosen move.
     *
     * @param game2         the game being played (not read by this method)
     * @param context       current state; the root of the search
     * @param maxSeconds    time budget in seconds; a non-positive value means no time limit
     * @param maxIterations iteration budget; a negative value means no iteration limit
     * @param maxDepth      not read by this implementation
     * @return the move picked by the final move selection strategy
     */
    @Override // util.AI
    public Move selectAction(Game game2, Context context, double maxSeconds, int maxIterations, int maxDepth) {
        final long startTime = System.currentTimeMillis();
        long stopTime = maxSeconds > 0.0d ? startTime + ((long) (maxSeconds * 1000.0d)) : Long.MAX_VALUE;
        final int iterationLimit = maxIterations >= 0 ? maxIterations : Integer.MAX_VALUE;

        // Try to walk the retained tree forward to the node matching the current state.
        advanceRootAlongHistory(context);

        if (this.rootNode == null || !this.treeReuse) {
            this.rootNode = createNode(this, null, null, null, context);
        } else {
            // The reused node becomes the new root, so detach it from its former parent.
            this.rootNode.setParent(null);
        }
        this.rootNode.rootInit(context);

        // With a single legal move, use the (shorter) auto-play budget instead of the caller's.
        if (this.rootNode.numLegalMoves() == 1 && this.autoPlaySeconds >= 0.0d && this.autoPlaySeconds < maxSeconds) {
            stopTime = startTime + ((long) (this.autoPlaySeconds * 1000.0d));
        }

        this.lastActionHistorySize = context.trial().numMoves();
        this.lastNumPlayoutActions = 0;

        int numIterations = 0;
        while (numIterations < iterationLimit && System.currentTimeMillis() < stopTime && !this.wantsInterrupt) {
            BaseNode current = this.rootNode;
            current.startNewIteration(context);

            // Selection / expansion: descend until the trial has a status or a new node is created.
            while (current.contextRef().trial().status() == null) {
                final int selectedIdx = this.selectionStrategy.select(current);
                final BaseNode existingChild = current.childForNthLegalMove(selectedIdx);
                final Context childContext = current.traverse(selectedIdx);

                if (existingChild == null) {
                    final BaseNode newChild = createNode(
                            this,
                            current,
                            childContext.trial().lastMove(),
                            current.nthLegalMove(selectedIdx),
                            childContext);
                    current.addChild(newChild, selectedIdx);
                    current = newChild;
                    current.updateContextRef();
                    break;
                }

                current = existingChild;
                current.updateContextRef();
            }

            // Playout: only when the trial at the reached node is not over yet.
            final Context playoutContext = current.playoutContext();
            int numPlayoutActions = 0;
            if (!current.contextRef().trial().over()) {
                final int numMovesBeforePlayout = current.contextRef().trial().numMoves();
                numPlayoutActions = this.playoutStrategy.runPlayout(playoutContext).numMoves() - numMovesBeforePlayout;
                this.lastNumPlayoutActions += playoutContext.trial().numMoves() - numMovesBeforePlayout;
            }

            this.backpropagation.update(current, playoutContext, AIUtils.agentUtilities(playoutContext), numPlayoutActions);
            numIterations++;
        }

        this.lastNumMctsIterations = numIterations;
        final Move selectedMove = this.finalMoveSelectionStrategy.selectMove(this.rootNode);

        if (this.wantsInterrupt) {
            this.analysisReport = null;
        } else {
            updateAnalysisReport(selectedMove);
        }

        if (!this.preserveRootNode) {
            if (!this.treeReuse) {
                this.rootNode = null;
            } else if (!this.wantsInterrupt) {
                // Keep the subtree below the move we are about to play for the next search.
                this.rootNode = this.rootNode.findChildForMove(selectedMove);
                if (this.rootNode != null) {
                    this.rootNode.setParent(null);
                    this.lastActionHistorySize++;
                }
            }
        }
        return selectedMove;
    }

    /**
     * Re-roots the retained tree by following every move in the trial history
     * that was played since the root was last set. Sets the root to
     * {@code null} when the history is shorter than expected or a move has no
     * matching child. Does nothing when tree reuse is off or no tree is held.
     */
    private void advanceRootAlongHistory(Context context) {
        if (!this.treeReuse || this.rootNode == null) {
            return;
        }
        final List<Move> history = context.trial().generateCompleteMovesList();
        final int numUnseenMoves = history.size() - this.lastActionHistorySize;
        if (numUnseenMoves < 0) {
            this.rootNode = null;
            return;
        }
        for (int remaining = numUnseenMoves; remaining > 0; remaining--) {
            this.rootNode = this.rootNode.findChildForMove(history.get(history.size() - remaining));
            if (this.rootNode == null) {
                return;
            }
        }
    }

    /**
     * Records the value estimate of the selected move's child (if that child
     * exists) and rebuilds {@link #analysisReport}. When no matching child is
     * found, the visit count is reported as -1 and the previous value estimate
     * is left in place.
     */
    private void updateAnalysisReport(Move selectedMove) {
        int selectedChildVisits = -1;
        for (int i = 0; i < this.rootNode.numLegalMoves(); i++) {
            final BaseNode child = this.rootNode.childForNthLegalMove(i);
            if (child != null && this.rootNode.nthLegalMove(i).equals(selectedMove)) {
                final int mover = this.rootNode.deterministicContextRef().state().mover();
                selectedChildVisits = child.numVisits();
                this.lastReturnedMoveValueEst = child.averageScore(mover, this.rootNode.deterministicContextRef().state());
                break;
            }
        }
        this.analysisReport = this.friendlyName + " made move after " + this.rootNode.numVisits() + " iterations (selected child visits = " + selectedChildVisits + ", value = " + this.lastReturnedMoveValueEst + ").";
    }

    /**
     * Creates a tree node: a {@link Node} unless the open-loop flag bit is set
     * in the current game flags, in which case an {@link OpenLoopNode}.
     */
    private BaseNode createNode(MCTS mcts, BaseNode parent, Move parentMove, Move parentMoveWithoutConseq, Context context) {
        if ((this.currentGameFlags & OPEN_LOOP_FLAG_MASK) == 0) {
            return new Node(mcts, parent, parentMove, parentMoveWithoutConseq, context);
        }
        return new OpenLoopNode(mcts, parent, parentMove, parentMoveWithoutConseq, context.game());
    }

    /**
     * @param d time budget in seconds used when the root has exactly one legal
     *          move; a negative value disables the shortcut
     */
    public void setAutoPlaySeconds(double d) {
        this.autoPlaySeconds = d;
    }

    /** @param z whether to keep the relevant subtree between searches */
    public void setTreeReuse(boolean z) {
        this.treeReuse = z;
    }

    /** @return the flags obtained from the selection strategy at construction time */
    public int backpropFlags() {
        return this.backpropFlags;
    }

    /** @return the learned selection policy, or {@code null} if none was set */
    public SoftmaxPolicy learnedSelectionPolicy() {
        return this.learnedSelectionPolicy;
    }

    /** @return the playout strategy */
    public PlayoutStrategy playoutStrategy() {
        return this.playoutStrategy;
    }

    /** @return the current {@link QInit} mode */
    public QInit qInit() {
        return this.qInit;
    }

    /** @return the current root node, possibly {@code null} */
    public BaseNode rootNode() {
        return this.rootNode;
    }

    /** @param softmaxPolicy the learned selection policy; may be {@code null} */
    public void setLearnedSelectionPolicy(SoftmaxPolicy softmaxPolicy) {
        this.learnedSelectionPolicy = softmaxPolicy;
    }

    /** @param qInit the {@link QInit} mode to use */
    public void setQInit(QInit qInit) {
        this.qInit = qInit;
    }

    /** @param z when true, {@link #selectAction} leaves the root node in place on return */
    public void setPreserveRootNode(boolean z) {
        this.preserveRootNode = z;
    }

    /** @return iterations completed by the last search, or -1 if none ran since initialisation */
    public int getNumMctsIterations() {
        return this.lastNumMctsIterations;
    }

    /** @return playout actions accumulated by the last search, or -1 if none ran since initialisation */
    public int getNumPlayoutActions() {
        return this.lastNumPlayoutActions;
    }

    /**
     * Resets all per-game state and initialises the learned policy and, if it
     * is a distinct {@link AI}, the playout strategy.
     *
     * @param game2 the game about to be played
     * @param i     forwarded unchanged to the nested {@code initAI} calls
     */
    @Override // util.AI
    public void initAI(Game game2, int i) {
        this.currentGameFlags = game2.gameFlags();
        this.lastNumMctsIterations = -1;
        this.lastNumPlayoutActions = -1;
        this.rootNode = null;
        this.lastActionHistorySize = 0;

        if (this.learnedSelectionPolicy != null) {
            this.learnedSelectionPolicy.initAI(game2, i);
        }
        // Avoid initialising the same object twice when the policy doubles as playout strategy.
        if ((this.playoutStrategy instanceof AI) && this.playoutStrategy != this.learnedSelectionPolicy) {
            ((AI) this.playoutStrategy).initAI(game2, i);
        }

        this.lastReturnedMoveValueEst = 0.0d;
        this.analysisReport = null;
    }

    /**
     * @param game2 the game to check
     * @return false if the unsupported flag bit is set or the learned policy
     *         rejects the game; otherwise whatever the playout strategy reports
     */
    @Override // util.AI
    public boolean supportsGame(Game game2) {
        if ((game2.gameFlags() & UNSUPPORTED_FLAG_MASK) != 0) {
            return false;
        }
        if (this.learnedSelectionPolicy != null && !this.learnedSelectionPolicy.supportsGame(game2)) {
            return false;
        }
        return this.playoutStrategy.playoutSupportsGame(game2);
    }

    /** @return the value estimate recorded for the most recently returned move */
    @Override // util.AI
    public double estimateValue() {
        return this.lastReturnedMoveValueEst;
    }

    /** @return the report built by the last uninterrupted search, or {@code null} */
    @Override // util.AI
    public String generateAnalysisReport() {
        return this.analysisReport;
    }

    /**
     * Collects, per legal move at the root, the child's visit count and a
     * value estimate clamped to {@code [-1, 1]}.
     *
     * @return the visualisation data, or {@code null} when there is no root or
     *         the root has no visits
     */
    @Override // util.AI
    public AI.AIVisualisationData aiVisualisationData() {
        final BaseNode root = this.rootNode;
        if (root == null || root.numVisits() == 0) {
            return null;
        }

        final int numLegalMoves = root.numLegalMoves();
        final FVector visitCounts = new FVector(numLegalMoves);
        final FVector valueEstimates = new FVector(numLegalMoves);
        final int mover = root.contextRef().state().mover();
        final FastArrayList<Move> moves = new FastArrayList<>();

        for (int i = 0; i < numLegalMoves; i++) {
            final BaseNode child = root.childForNthLegalMove(i);
            if (child == null) {
                visitCounts.set(i, 0.0f);
                if (root.numVisits() == 0) {
                    valueEstimates.set(i, 0.0f);
                } else {
                    valueEstimates.set(i, (float) root.valueEstimateUnvisitedChildren(mover, root.contextRef().state()));
                }
            } else {
                visitCounts.set(i, child.numVisits());
                valueEstimates.set(i, (float) child.averageScore(mover, root.contextRef().state()));
            }

            // Clamp the estimate into [-1, 1].
            if (valueEstimates.get(i) > 1.0f) {
                valueEstimates.set(i, 1.0f);
            } else if (valueEstimates.get(i) < -1.0f) {
                valueEstimates.set(i, -1.0f);
            }

            moves.add(root.nthLegalMove(i));
        }
        return new AI.AIVisualisationData(visitCounts, valueEstimates, moves);
    }

    /**
     * Builds an agent from JSON. The objects under {@code "selection"},
     * {@code "playout"} and {@code "final_move"} are handed to the respective
     * strategy {@code fromJson} factories; {@code "tree_reuse"} and
     * {@code "friendly_name"} are applied only when present.
     *
     * @param jSONObject the agent description
     * @return the configured agent
     */
    public static MCTS fromJson(JSONObject jSONObject) {
        final SelectionStrategy selection = SelectionStrategy.fromJson(jSONObject.getJSONObject("selection"));
        final PlayoutStrategy playout = PlayoutStrategy.fromJson(jSONObject.getJSONObject("playout"));
        final FinalMoveSelectionStrategy finalMove = FinalMoveSelectionStrategy.fromJson(jSONObject.getJSONObject("final_move"));

        final MCTS agent = new MCTS(selection, playout, finalMove);
        if (jSONObject.has("tree_reuse")) {
            agent.setTreeReuse(jSONObject.getBoolean("tree_reuse"));
        }
        if (jSONObject.has("friendly_name")) {
            agent.friendlyName = jSONObject.getString("friendly_name");
        }
        return agent;
    }

    /**
     * @return the moves of the current root node; throws
     *         {@link NullPointerException} when no root is held
     */
    @Override // expert_iteration.ExpertPolicy
    public FastArrayList<Move> lastSearchRootMoves() {
        return this.rootNode.movesFromNode();
    }

    /**
     * @param d forwarded to the root's {@code computeVisitCountPolicy}
     * @return the visit-count policy of the current root node; throws
     *         {@link NullPointerException} when no root is held
     */
    @Override // expert_iteration.ExpertPolicy
    public FVector computeExpertPolicy(double d) {
        return this.rootNode.computeVisitCountPolicy(d);
    }

    /**
     * @return the experience generated by the current root node; throws
     *         {@link NullPointerException} when no root is held
     */
    @Override // expert_iteration.ExpertPolicy
    public ExItExperience generateExItExperience() {
        return this.rootNode.generateExItExperience();
    }

    /**
     * Builds an agent from configuration lines. Each line is split on commas;
     * the first part is a {@code key=value} pair whose key is matched
     * case-insensitively, and the complete split line is handed to the chosen
     * component for further customisation. Recognised keys are
     * {@code selection}, {@code playout}, {@code final_move},
     * {@code tree_reuse}, {@code learned_selection_policy} and
     * {@code friendly_name}; other lines are ignored.
     *
     * <p>Defaults: UCB1 selection, random playouts, robust-child final move,
     * tree reuse off, no learned policy, friendly name "MCTS".
     *
     * @param strArr the configuration lines
     * @return the configured agent
     * @throws ClassCastException if {@code learned_selection_policy=playout} is
     *         given while the playout strategy is not a {@link SoftmaxPolicy}
     */
    public static MCTS fromLines(String[] strArr) {
        SelectionStrategy selection = new UCB1();
        PlayoutStrategy playout = new RandomPlayout(DEFAULT_RANDOM_PLAYOUT_ARG);
        FinalMoveSelectionStrategy finalMove = new RobustChild();
        boolean treeReuse = false;
        SoftmaxPolicy learnedPolicy = null;
        String friendlyName = "MCTS";

        for (String line : strArr) {
            final String[] lineParts = line.split(",");
            // Locale.ROOT keeps key matching independent of the JVM's default locale.
            final String head = lineParts[0].toLowerCase(Locale.ROOT);

            if (head.startsWith("selection=")) {
                if (head.endsWith("ucb1")) {
                    selection = new UCB1();
                    selection.customise(lineParts);
                } else if (head.endsWith("ag0selection") || head.endsWith("alphago0selection")) {
                    selection = new AG0Selection();
                    selection.customise(lineParts);
                } else {
                    System.err.println("Unknown selection strategy: " + line);
                }
            } else if (head.startsWith("playout=")) {
                playout = PlayoutStrategy.constructPlayoutStrategy(lineParts);
            } else if (head.startsWith("final_move=")) {
                if (head.endsWith("maxavgscore")) {
                    finalMove = new MaxAvgScore();
                    finalMove.customise(lineParts);
                } else if (head.endsWith("robustchild")) {
                    finalMove = new RobustChild();
                    finalMove.customise(lineParts);
                } else if (head.endsWith("proportional") || head.endsWith("proportionalexpvisitcount")) {
                    finalMove = new ProportionalExpVisitCount(1.0d);
                    finalMove.customise(lineParts);
                } else {
                    System.err.println("Unknown final move selection strategy: " + line);
                }
            } else if (head.startsWith("tree_reuse=")) {
                if (head.endsWith("true")) {
                    treeReuse = true;
                } else if (head.endsWith("false")) {
                    treeReuse = false;
                } else {
                    System.err.println("Error in line: " + line);
                }
            } else if (head.startsWith("learned_selection_policy=")) {
                if (head.endsWith("playout")) {
                    // Reuse whichever playout strategy has been configured so far.
                    learnedPolicy = (SoftmaxPolicy) playout;
                } else if (head.endsWith("softmax") || head.endsWith("softmaxplayout")) {
                    learnedPolicy = new SoftmaxPolicy();
                    learnedPolicy.customise(lineParts);
                }
            } else if (head.startsWith("friendly_name=")) {
                // Take the value from the original (non-lowercased) text to keep its casing.
                friendlyName = lineParts[0].substring("friendly_name=".length());
            }
        }

        final MCTS agent = new MCTS(selection, playout, finalMove);
        agent.setTreeReuse(treeReuse);
        agent.setLearnedSelectionPolicy(learnedPolicy);
        agent.friendlyName = friendlyName;
        return agent;
    }
}
