package search.mcts.selection;

import java.util.concurrent.ThreadLocalRandom;
import main.collections.FVector;
import search.mcts.nodes.BaseNode;

/* loaded from: input_file:search/mcts/selection/ExItSelection.class */
/**
 * Selection strategy that adds a learned prior-policy term to UCB1, in the style of Expert
 * Iteration (ExIt).
 *
 * <p>For every legal move {@code i} of a node, the selection value is
 *
 * <pre>
 *   score(i) + C * sqrt(ln(max(1, N)) / n(i)) + w * prior(i) / (n(i) + 1)
 * </pre>
 *
 * <p>The terms are:
 * <ul>
 *   <li>{@code score(i)}: the child's average score for the player to move in the node's state;
 *   <li>{@code N}: the node's sum of legal child visits;
 *   <li>{@code n(i)}: the child's visit count;
 *   <li>{@code prior(i)}: entry {@code i} of the node's learned selection policy;
 *   <li>{@code C}: the exploration constant;
 *   <li>{@code w}: the prior policy weight.
 * </ul>
 *
 * <p>A move whose child node does not exist yet is handled differently:
 * <ul>
 *   <li>its score is the node's value estimate for unvisited children;
 *   <li>its visit count is 0 in the prior term;
 *   <li>its exploration term is {@code sqrt(ln(max(1, N)))}, as if it had one visit.
 * </ul>
 *
 * <p>The move with the highest value is selected; ties are broken uniformly at random.
 */
public final class ExItSelection implements SelectionStrategy {

    /** Customisation prefix that sets {@link #explorationConstant}. */
    private static final String EXPLORATION_CONSTANT_KEY = "explorationconstant=";

    /** Customisation prefix that sets {@link #priorPolicyWeight}. */
    private static final String PRIOR_POLICY_WEIGHT_KEY = "priorpolicyweight=";

    /** Weight {@code C} of the UCB1 exploration term. */
    protected double explorationConstant;

    /** Weight {@code w} of the learned prior-policy term. */
    protected double priorPolicyWeight;

    /**
     * Creates the strategy with exploration constant {@code sqrt(2)}.
     *
     * @param priorPolicyWeight weight {@code w} of the learned prior-policy term
     */
    public ExItSelection(final double priorPolicyWeight) {
        this(Math.sqrt(2.0), priorPolicyWeight);
    }

    /**
     * Creates the strategy with explicit weights.
     *
     * @param explorationConstant weight {@code C} of the UCB1 exploration term
     * @param priorPolicyWeight   weight {@code w} of the learned prior-policy term
     */
    public ExItSelection(final double explorationConstant, final double priorPolicyWeight) {
        this.explorationConstant = explorationConstant;
        this.priorPolicyWeight = priorPolicyWeight;
    }

    /**
     * Picks the legal move of {@code current} with the highest selection value.
     *
     * @param current node from which to select a move
     * @return index of the selected legal move, or {@code -1} if nothing was selected (for
     *     example when the node has no legal moves)
     */
    @Override
    public int select(final BaseNode current) {
        final FVector distribution = current.learnedSelectionPolicy();
        // max(1, ...) keeps the logarithm at 0 or above, even before any child has been visited.
        final double parentLog = Math.log(Math.max(1, current.sumLegalChildVisits()));
        final int numChildren = current.numLegalMoves();
        final int mover = current.contextRef().state().mover();
        final double unvisitedValueEstimate =
            current.valueEstimateUnvisitedChildren(mover, current.contextRef().state());

        int bestIdx = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        int numBestFound = 0;

        for (int i = 0; i < numChildren; ++i) {
            final BaseNode child = current.childForNthLegalMove(i);
            final double exploit;
            final int numVisits;
            final double explore;

            if (child == null) {
                exploit = unvisitedValueEstimate;
                numVisits = 0;
                // Unexpanded children are explored as if they had exactly one visit.
                explore = Math.sqrt(parentLog);
            } else {
                exploit = child.averageScore(mover, current.contextRef().state());
                numVisits = child.numVisits();
                // NOTE(review): an existing child with 0 visits yields +Infinity here, or NaN
                // when parentLog is 0; a NaN value is never selected below.
                explore = Math.sqrt(parentLog / numVisits);
            }

            // The prior term shrinks as the child accumulates visits.
            final double value = exploit
                + explorationConstant * explore
                + priorPolicyWeight * (distribution.get(i) / (numVisits + 1));

            if (value > bestValue) {
                bestValue = value;
                bestIdx = i;
                numBestFound = 1;
            } else if (value == bestValue) {
                ++numBestFound;
                // Reservoir sampling. Replacing the incumbent with probability 1/numBestFound
                // leaves every tied move equally likely to be the final choice.
                if (ThreadLocalRandom.current().nextInt(numBestFound) == 0) {
                    bestIdx = i;
                }
            }
        }

        return bestIdx;
    }

    /**
     * Returns the backpropagation flags this strategy requires.
     *
     * @return always {@code 0}
     */
    @Override
    public int backpropFlags() {
        return 0;
    }

    /**
     * Applies customisation strings of the form {@code explorationconstant=<double>} or
     * {@code priorpolicyweight=<double>}.
     *
     * <p>Index 0 is skipped (presumably the strategy name). Unrecognised entries are reported on
     * {@code System.err} and otherwise ignored.
     *
     * @param inputs customisation strings
     * @throws NumberFormatException if a recognised key is followed by a malformed number
     */
    @Override
    public void customise(final String[] inputs) {
        for (int i = 1; i < inputs.length; ++i) {
            final String input = inputs[i];

            if (input.startsWith(EXPLORATION_CONSTANT_KEY)) {
                explorationConstant =
                    Double.parseDouble(input.substring(EXPLORATION_CONSTANT_KEY.length()));
            } else if (input.startsWith(PRIOR_POLICY_WEIGHT_KEY)) {
                priorPolicyWeight =
                    Double.parseDouble(input.substring(PRIOR_POLICY_WEIGHT_KEY.length()));
            } else {
                System.err.println("ExItSelection ignores unknown customization: " + input);
            }
        }
    }
}
