package search.mcts.selection;

import java.util.concurrent.ThreadLocalRandom;
import search.mcts.nodes.BaseNode;

/**
 * UCB1-GRAVE selection strategy for MCTS.
 *
 * <p>Each legal move of the node being selected from is scored as
 * {@code (1 - beta) * meanScore + beta * meanGrave + explorationConstant * explore}, where
 * {@code meanScore} is the child's own average score, {@code meanGrave} is the average score
 * stored in the GRAVE statistics of a per-thread "reference" node for the same move, and
 * {@code explore} is the usual UCB1 term {@code sqrt(ln(parentVisits) / childVisits)}.
 *
 * <p>The reference node is (re)set to the node being selected from whenever no reference is
 * set for the current thread, the node is the root, or the node has more than {@code ref}
 * visits. It is cleared once selection reaches a move without an expanded child.
 */
public class UCB1GRAVE implements SelectionStrategy {
    /** Visit-count threshold; a node with more visits than this becomes the reference node. */
    protected final int ref;
    /** Bias constant used in the denominator of the GRAVE blending weight (beta). */
    protected final double bias;
    /** Multiplier applied to the UCB1 exploration term. */
    protected double explorationConstant;
    /** Per-thread reference node whose GRAVE statistics are consulted; {@code null} when unset. */
    protected ThreadLocal<BaseNode> currentRefNode = new ThreadLocal<>();

    /** Creates the strategy with {@code ref = 100}, {@code bias = 1e-5}, exploration {@code sqrt(2)}. */
    public UCB1GRAVE() {
        this(100, 1.0E-5d, Math.sqrt(2.0d));
    }

    /**
     * Creates the strategy with explicit parameters.
     *
     * @param ref visit threshold above which a node becomes the reference node
     * @param bias bias constant in the beta denominator
     * @param explorationConstant multiplier for the UCB1 exploration term
     */
    public UCB1GRAVE(int ref, double bias, double explorationConstant) {
        this.ref = ref;
        this.bias = bias;
        this.explorationConstant = explorationConstant;
    }

    /**
     * Picks the index of the legal move to descend into from {@code baseNode}.
     * Ties on the best value are broken uniformly at random.
     *
     * @param baseNode node whose legal moves are being compared
     * @return index (in {@code [0, numLegalMoves)}) of the selected move, or {@code -1} if the
     *     node has no legal moves
     */
    @Override // search.mcts.selection.SelectionStrategy
    public int select(BaseNode baseNode) {
        final double parentLog = Math.log(Math.max(1, baseNode.sumLegalChildVisits()));
        final int numChildren = baseNode.numLegalMoves();
        final int mover = baseNode.contextRef().state().mover();
        final double unvisitedValueEstimate =
                baseNode.valueEstimateUnvisitedChildren(mover, baseNode.contextRef().state());

        // Re-anchor the reference node when none is set for this thread, at the root,
        // or once this node has more than `ref` visits.
        if (currentRefNode.get() == null || baseNode.numVisits() > ref || baseNode.parent() == null) {
            currentRefNode.set(baseNode);
        }
        final BaseNode refNode = currentRefNode.get();

        int bestIdx = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        int numBestFound = 0;

        for (int i = 0; i < numChildren; ++i) {
            final BaseNode child = baseNode.childForNthLegalMove(i);
            final double meanScore;
            final double meanGrave;
            final double beta;
            final double explore;

            if (child == null) {
                // Unexpanded move: use the node's estimate for unvisited children, no GRAVE term.
                meanScore = unvisitedValueEstimate;
                meanGrave = 0.0d;
                beta = 0.0d;
                explore = Math.sqrt(parentLog);
            } else {
                meanScore = child.averageScore(mover, baseNode.contextRef().state());
                final double childVisits = child.numVisits();
                final BaseNode.NodeStatistics graveStats = refNode.graveStats(
                        new BaseNode.MoveKey(child.parentMove(), baseNode.contextRef().trial().numMoves()));
                final int graveVisits = graveStats.visitCount;
                if (graveVisits > 0) {
                    meanGrave = graveStats.accumulatedScore / graveVisits;
                    beta = graveVisits / (graveVisits + childVisits + bias * graveVisits * childVisits);
                } else {
                    // No GRAVE samples for this move: rely on the child's own mean only.
                    // Dividing by zero here would yield NaN, and a NaN value can never be selected.
                    meanGrave = 0.0d;
                    beta = 0.0d;
                }
                explore = Math.sqrt(parentLog / childVisits);
            }

            final double value = ((1.0d - beta) * meanScore) + (beta * meanGrave) + (explorationConstant * explore);

            if (value > bestValue) {
                bestValue = value;
                bestIdx = i;
                numBestFound = 1;
            } else if (value == bestValue) {
                // Reservoir sampling: the k-th tied move replaces the current pick with
                // probability 1/k, giving a uniform choice among all tied moves.
                ++numBestFound;
                if (ThreadLocalRandom.current().nextInt(numBestFound) == 0) {
                    bestIdx = i;
                }
            }
        }

        // Selection stops here (no move, or an unexpanded child): drop this thread's reference
        // so the next call re-anchors, and so the ThreadLocal no longer keeps the node reachable.
        if (bestIdx < 0 || baseNode.childForNthLegalMove(bestIdx) == null) {
            currentRefNode.set(null);
        }
        return bestIdx;
    }

    /**
     * Returns the backpropagation flags this strategy requires.
     *
     * @return {@code 1} — presumably the flag requesting GRAVE statistics to be recorded during
     *     backpropagation (which {@link #select} reads via {@code graveStats}); TODO confirm
     *     against the flag constants used by the backpropagation code
     */
    @Override // search.mcts.selection.SelectionStrategy
    public int backpropFlags() {
        return 1;
    }

    /**
     * Customisation hook; this strategy currently accepts no customisation and ignores
     * {@code strArr}.
     *
     * @param strArr customisation inputs (ignored)
     */
    @Override // search.mcts.selection.SelectionStrategy
    public void customise(String[] strArr) {
        // Intentionally a no-op.
    }
}
