package search.mcts.utils;

import main.collections.FVector;
import search.mcts.nodes.BaseNode;
import util.state.State;

/* loaded from: input_file:search/mcts/utils/RegPolOptMCTS.class */
/**
 * Static helpers for an MCTS variant that treats action selection as regularized policy
 * optimisation. The helpers compute:
 * <ul>
 *   <li>the regularized policy pi-bar ({@link #computePiBar}),</li>
 *   <li>the smoothed empirical visit policy pi-hat ({@link #computePiHat}),</li>
 *   <li>the regularization multiplier lambda ({@link #computeLambdaMultiplier}).</li>
 * </ul>
 *
 * <p>NOTE(review): the formulas appear to follow Grill et al., "Monte-Carlo Tree Search as
 * Regularized Policy Optimization" (ICML 2020). Confirm against the paper before relying on
 * that correspondence.
 */
public class RegPolOptMCTS {

    /** Stop bisecting once |sum(piBar) - 1| falls below this value. */
    private static final double ALPHA_TOLERANCE = 1.0E-4d;

    /**
     * Hard cap on bisection steps. Each step halves the bracket, so 100 steps exhaust double
     * precision for any realistic bracket. The cap guarantees termination even when the
     * tolerance cannot be reached; the former recursive form could overflow the stack instead.
     */
    private static final int MAX_BISECTION_ITERATIONS = 100;

    private RegPolOptMCTS() {
        // Static utility class; not instantiable.
    }

    /**
     * Computes the regularized policy pi-bar over the legal moves of a node.
     *
     * <p>For each legal move i, with q_i its value estimate and pi_i the learned selection policy:
     * pi-bar_i = lambda * pi_i / (alpha* - q_i). Here alpha* is found by bisection so that the
     * unnormalised values sum to (approximately) 1. The result is then renormalised to sum to 1.
     *
     * @param baseNode node whose legal moves define the distribution
     * @param d        exploration constant passed to {@link #computeLambdaMultiplier}
     * @return pi-bar, one entry per legal move. It is empty if the node has no legal moves.
     */
    public static FVector computePiBar(BaseNode baseNode, double d) {
        final int numLegalMoves = baseNode.numLegalMoves();
        if (numLegalMoves == 0) {
            // With no moves the pi-bar sum is always 0. Searching for alpha* would never reach 1.
            return new FVector(0);
        }

        final State state = baseNode.contextRef().state();

        double lambda = computeLambdaMultiplier(baseNode, d);
        if (lambda == 0.0d) {
            // lambda is 0 when the node has no visits. Every pi-bar term would then be 0 and the
            // sum could never reach 1, so substitute 1 to keep the search well-defined.
            lambda = 1.0d;
        }

        final FVector pi = baseNode.learnedSelectionPolicy();
        final FVector qs = new FVector(numLegalMoves);

        // Bracket for alpha*.
        // At alphaMin = max_i(q_i + lambda*pi_i), the maximising term alone equals 1.
        // At alphaMax = max_i(q_i + lambda), every term is <= pi_i, so the sum is <= sum(pi).
        double alphaMin = Double.NEGATIVE_INFINITY;
        double alphaMax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numLegalMoves; i++) {
            final BaseNode child = baseNode.childForNthLegalMove(i);
            // Unvisited children use the parent's estimate for unvisited moves; visited children
            // use their own average score.
            final float q = (float) (child == null
                    ? baseNode.valueEstimateUnvisitedChildren(state.mover(), state)
                    : child.averageScore(state.mover(), state));
            qs.set(i, q);
            // Build the bracket from the same float-rounded q that the search will read back.
            alphaMin = Math.max(alphaMin, q + lambda * pi.get(i));
            alphaMax = Math.max(alphaMax, q + lambda);
        }

        final double alphaStar = alphaStarBinarySearch(alphaMin, alphaMax, lambda, pi, qs);

        final FVector piBar = new FVector(numLegalMoves);
        for (int i = 0; i < numLegalMoves; i++) {
            piBar.set(i, (float) ((lambda * pi.get(i)) / (alphaStar - qs.get(i))));
        }
        // alpha* only meets the sum constraint within ALPHA_TOLERANCE, so renormalise exactly.
        piBar.div(piBar.sum());
        return piBar;
    }

    /**
     * Computes the add-one-smoothed empirical visit policy pi-hat.
     *
     * <p>The entry for move i is (n_i + 1) / (N + |A|). Here n_i is the child's visit count
     * (0 if the child has not been expanded), N is the node's visit count, and |A| is the
     * number of legal moves.
     *
     * <p>NOTE(review): the entries sum to exactly 1 only if N equals the sum of the children's
     * visit counts. Confirm how {@code numVisits()} is accounted for at this node.
     *
     * @param baseNode node whose children's visit counts are used
     * @return pi-hat, one entry per legal move
     */
    public static FVector computePiHat(BaseNode baseNode) {
        final int numLegalMoves = baseNode.numLegalMoves();
        final FVector piHat = new FVector(numLegalMoves);
        final float denominator = baseNode.numVisits() + numLegalMoves;
        for (int i = 0; i < numLegalMoves; i++) {
            final BaseNode child = baseNode.childForNthLegalMove(i);
            piHat.set(i, (child == null ? 1 : child.numVisits() + 1) / denominator);
        }
        return piHat;
    }

    /**
     * Computes the regularization multiplier lambda = d * sqrt(N) / (|A| + N).
     *
     * <p>N is the node's visit count and |A| is its number of legal moves.
     *
     * @param baseNode node supplying N and |A|
     * @param d        exploration constant
     * @return lambda, or 0 when both N and |A| are 0 (the formula is 0/0 there)
     */
    public static double computeLambdaMultiplier(BaseNode baseNode, double d) {
        final double numVisits = baseNode.numVisits();
        final double denominator = baseNode.numLegalMoves() + numVisits;
        if (denominator == 0.0d) {
            return 0.0d;
        }
        return d * (Math.sqrt(numVisits) / denominator);
    }

    /**
     * Finds alpha such that sum_i lambda*pi_i / (alpha - q_i) is approximately 1, by bisection
     * on [alphaMin, alphaMax].
     *
     * <p>The sum decreases as alpha grows (for alpha above every q_i with lambda, pi_i >= 0).
     * So a sum below 1 moves the upper bound down, and a sum above 1 moves the lower bound up.
     *
     * @return the first midpoint within {@link #ALPHA_TOLERANCE}. If that tolerance is never
     *     reached, returns the last midpoint evaluated.
     */
    private static double alphaStarBinarySearch(
            double alphaMin, double alphaMax, double lambda, FVector pi, FVector qs) {
        double lo = alphaMin;
        double hi = alphaMax;
        double alpha = (lo + hi) / 2.0d;
        for (int iter = 0; iter < MAX_BISECTION_ITERATIONS; iter++) {
            alpha = (lo + hi) / 2.0d;
            final double sum = sumPiBar(alpha, lambda, pi, qs);
            if (Math.abs(sum - 1.0d) < ALPHA_TOLERANCE) {
                return alpha;
            }
            if (sum < 1.0d) {
                hi = alpha; // sum too small -> alpha too large
            } else {
                lo = alpha; // sum too large (or NaN) -> alpha too small
            }
            if (!(lo < hi)) {
                // Bracket has collapsed (or a bound is NaN); no further progress is possible.
                return alpha;
            }
        }
        return alpha;
    }

    /** Returns the unnormalised pi-bar sum, sum_i lambda*pi_i / (alpha - q_i). */
    private static double sumPiBar(double alpha, double lambda, FVector pi, FVector qs) {
        double sum = 0.0d;
        for (int i = 0; i < pi.dim(); i++) {
            sum += (lambda * pi.get(i)) / (alpha - qs.get(i));
        }
        return sum;
    }
}
