package search.mcts.playout;

import game.Game;
import java.util.concurrent.ThreadLocalRandom;
import main.collections.FVector;
import main.collections.FastArrayList;
import other.context.Context;
import other.move.Move;
import other.playout.PlayoutMoveSelector;
import other.trial.Trial;
import playout_move_selectors.EpsilonGreedyWrapper;
import search.mcts.MCTS;

/**
 * MAST (Move-Average Sampling Technique) playout strategy for {@link MCTS}.
 *
 * <p>During a playout, every candidate move is scored using the action statistics that the owning
 * {@link MCTS} instance stores for it. These statistics are looked up by the move together with
 * the number of moves already played in the trial. A move with at least one visit is scored by its
 * average ({@code accumulatedScore / visitCount}). A move that has never been visited gets a
 * default score of {@code 1}. The highest-scoring move that is really legal is played. Move
 * selection is wrapped in an {@link EpsilonGreedyWrapper} configured with {@link #epsilon}.
 */
public class MAST implements PlayoutStrategy {

    /** Case-insensitive prefix of the {@link #customise} option that sets {@link #playoutTurnLimit}. */
    private static final String PLAYOUT_TURN_LIMIT_PREFIX = "playoutturnlimit=";

    /** Case-insensitive prefix of the {@link #customise} option that sets {@link #epsilon}. */
    private static final String EPSILON_PREFIX = "epsilon=";

    /**
     * Turn limit passed to {@code Game.playout}. The default is {@code -1}. Presumably a
     * non-positive value means "no limit", since {@link #playoutSupportsGame} only accepts
     * deduction puzzles when this is {@code > 0}. TODO confirm.
     */
    protected int playoutTurnLimit;

    /** Epsilon handed to the {@link EpsilonGreedyWrapper} around the MAST selector (default 0.1). */
    protected double epsilon;

    /**
     * One selector per thread. The selector temporarily stores the calling {@link MCTS}, so it
     * must not be shared between concurrently running playouts.
     */
    protected ThreadLocal<MASTMoveSelector> moveSelector =
            ThreadLocal.withInitial(MASTMoveSelector::new);

    /**
     * Playout move selector that picks the move with the highest MAST score among those that are
     * really legal.
     */
    public static class MASTMoveSelector extends PlayoutMoveSelector {

        /**
         * The MCTS whose action statistics are consulted. {@link MAST#runPlayout} sets it only for
         * the duration of a playout; it is {@code null} otherwise.
         */
        protected MCTS mcts = null;

        protected MASTMoveSelector() {
        }

        /**
         * Scores every candidate move and returns the best-scoring one that passes
         * {@code isMoveReallyLegal}. Ties are presumably broken at random by
         * {@link FVector#argMaxRand()}.
         *
         * @param context current game context
         * @param maybeLegalMoves candidate moves, some of which may turn out not to be legal
         * @param p player index (not used by this selector)
         * @param isMoveReallyLegal final legality check applied to each chosen candidate
         * @return the selected move, or {@code null} if no candidate passes the legality check
         */
        @Override // other.playout.PlayoutMoveSelector
        public Move selectMove(
                Context context,
                FastArrayList<Move> maybeLegalMoves,
                int p,
                PlayoutMoveSelector.IsMoveReallyLegal isMoveReallyLegal) {
            final int numCandidates = maybeLegalMoves.size();
            // Loop-invariant: statistics lookups do not advance the trial.
            final int movesPlayed = context.trial().numMoves();

            final FVector scores = new FVector(numCandidates);
            for (int m = 0; m < numCandidates; m++) {
                final MCTS.ActionStatistics stats = mcts.getOrCreateActionStatsEntry(
                        new MCTS.MoveKey(maybeLegalMoves.get(m), movesPlayed));
                if (stats.visitCount > 0.0) {
                    // Explicit narrowing: the average is computed in double, FVector stores float.
                    scores.set(m, (float) (stats.accumulatedScore / stats.visitCount));
                } else {
                    scores.set(m, 1.0f);
                }
            }

            // Try candidates best-first. A rejected move is knocked out with -infinity so that
            // each candidate is checked at most once.
            for (int attempt = 0; attempt < numCandidates; attempt++) {
                final int best = scores.argMaxRand();
                final Move move = maybeLegalMoves.get(best);
                if (isMoveReallyLegal.checkMove(move)) {
                    return move;
                }
                scores.set(best, Float.NEGATIVE_INFINITY);
            }
            return null;
        }
    }

    /** Creates a MAST playout strategy with no turn limit ({@code -1}) and epsilon {@code 0.1}. */
    public MAST() {
        this(-1, 0.1);
    }

    /**
     * Creates a MAST playout strategy.
     *
     * @param playoutTurnLimit turn limit passed to {@code Game.playout} ({@code -1} by default)
     * @param epsilon epsilon for the {@link EpsilonGreedyWrapper} around the MAST selector
     */
    public MAST(int playoutTurnLimit, double epsilon) {
        this.playoutTurnLimit = playoutTurnLimit;
        this.epsilon = epsilon;
    }

    /**
     * Runs one playout from {@code context}, selecting moves with this thread's MAST selector.
     * The selector's reference to {@code mcts} is always cleared afterwards, even if the playout
     * throws.
     */
    @Override // search.mcts.playout.PlayoutStrategy
    public Trial runPlayout(MCTS mcts, Context context) {
        final MASTMoveSelector selector = moveSelector.get();
        selector.mcts = mcts;
        try {
            // NOTE(review): the meaning of the null / 1.0 / -1 arguments is defined by
            // Game.playout and is not visible here; they are passed through unchanged.
            return context.game().playout(
                    context,
                    null,
                    1.0,
                    new EpsilonGreedyWrapper(selector, epsilon),
                    -1,
                    playoutTurnLimit,
                    ThreadLocalRandom.current());
        } finally {
            selector.mcts = null;
        }
    }

    /**
     * Returns the backpropagation flags this strategy requires. NOTE(review): {@code 2} is
     * presumably the flag that makes backpropagation maintain the global action statistics read
     * by {@link MASTMoveSelector}; confirm against the backpropagation flag constants.
     */
    @Override // search.mcts.playout.PlayoutStrategy
    public int backpropFlags() {
        return 2;
    }

    /** Deduction puzzles are supported only when a positive playout turn limit is set. */
    @Override // search.mcts.playout.PlayoutStrategy
    public boolean playoutSupportsGame(Game game2) {
        return !game2.isDeductionPuzzle() || playoutTurnLimit() > 0;
    }

    /**
     * Applies {@code key=value} options, with keys matched case-insensitively and independently
     * of the default locale. The recognised keys are {@code playoutturnlimit=<int>} and
     * {@code epsilon=<double>}; other inputs are ignored.
     *
     * @param strArr option strings; element 0 is skipped (presumably the strategy's own name)
     * @throws NumberFormatException if a recognised option has a malformed value
     */
    @Override // search.mcts.playout.PlayoutStrategy
    public void customise(String[] strArr) {
        for (int i = 1; i < strArr.length; i++) {
            final String input = strArr[i];
            if (startsWithIgnoreCase(input, PLAYOUT_TURN_LIMIT_PREFIX)) {
                this.playoutTurnLimit =
                        Integer.parseInt(input.substring(PLAYOUT_TURN_LIMIT_PREFIX.length()));
            } else if (startsWithIgnoreCase(input, EPSILON_PREFIX)) {
                this.epsilon = Double.parseDouble(input.substring(EPSILON_PREFIX.length()));
            }
        }
    }

    /** Locale-independent, case-insensitive prefix test (unlike {@code toLowerCase().startsWith}). */
    private static boolean startsWithIgnoreCase(String s, String prefix) {
        return s.regionMatches(true, 0, prefix, 0, prefix.length());
    }

    /** Returns the turn limit passed to {@code Game.playout}. */
    public int playoutTurnLimit() {
        return this.playoutTurnLimit;
    }
}
