package search.mcts.nodes;

import com.itextpdf.text.pdf.ColumnText;
import game.Game;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import main.collections.FVector;
import main.collections.FastArrayList;
import other.context.Context;
import other.move.Move;
import other.state.State;
import policies.softmax.SoftmaxPolicyLinear;
import search.mcts.MCTS;
import training.expert_iteration.ExItExperience;

/* loaded from: input_file:search/mcts/nodes/BaseNode.class */
/**
 * Abstract base class for a node in an MCTS search tree.
 *
 * <p>Tracks per-agent sums of backpropagated utilities (and their squares), a visit count, a count
 * of in-flight "virtual" visits, optional per-move GRAVE statistics, and optional heuristic value
 * estimates. It also provides helpers that turn child visit counts into policies, entropies and
 * Expert Iteration training samples.
 *
 * <p>NOTE(review): {@link #numVisits} and the score arrays are plain fields. Synchronisation is
 * presumably the caller's responsibility via {@link #getLock()} — confirm against the MCTS driver.
 */
public abstract class BaseNode {
    /** Node this node hangs under; may be reassigned through {@link #setParent(BaseNode)}. */
    protected BaseNode parent;
    /** Move that was supplied as leading to this node. */
    protected final Move parentMove;
    /** Variant of {@link #parentMove} without consequents, as supplied by the creator. */
    protected final Move parentMoveWithoutConseq;
    /** Owning search; consulted for backprop flags, Q-init strategy and playout strategy. */
    protected final MCTS mcts;
    /** Sum of backpropagated utilities per agent; slot 0 is never updated by {@link #update(double[])}. */
    protected final double[] totalScores;
    /** Sum of squared backpropagated utilities, indexed like {@link #totalScores}. */
    protected final double[] sumSquaredScores;
    /** Per-move GRAVE statistics, or {@code null} when the corresponding backprop flag is not set. */
    protected final Map<MCTS.MoveKey, NodeStatistics> graveStats;
    /** Visit count; incremented by {@link #update(double[])}. */
    protected int numVisits = 0;
    /** In-flight visits; each counts as a loss (-1) in the score accessors until resolved by {@link #update(double[])}. */
    protected AtomicInteger numVirtualVisits = new AtomicInteger();
    /** Lock handed out by {@link #getLock()}. */
    protected transient ReentrantLock nodeLock = new ReentrantLock();
    /** Optional heuristic value estimates; {@code null} until {@link #setHeuristicValueEstimates(double[])} is called. */
    protected double[] heuristicValueEstimates = null;

    /** Visit count and accumulated score for a single move (used for GRAVE statistics). */
    public static class NodeStatistics {
        public int visitCount = 0;
        public double accumulatedScore = 0.0d;

        @Override
        public String toString() {
            return "[visits = " + this.visitCount + ", accum. score = " + this.accumulatedScore + "]";
        }
    }

    /**
     * Creates a node.
     *
     * @param mcts the owning search
     * @param parent the parent node
     * @param parentMove the move leading to this node
     * @param parentMoveWithoutConseq {@code parentMove} without its consequents
     * @param game the game; its player count sizes the per-agent score arrays
     */
    public BaseNode(
            final MCTS mcts,
            final BaseNode parent,
            final Move parentMove,
            final Move parentMoveWithoutConseq,
            final Game game) {
        this.mcts = mcts;
        this.parent = parent;
        this.parentMove = parentMove;
        this.parentMoveWithoutConseq = parentMoveWithoutConseq;
        // One slot per player plus slot 0, so the arrays can be indexed directly by agent number.
        final int numScoreSlots = game.players().count() + 1;
        this.totalScores = new double[numScoreSlots];
        this.sumSquaredScores = new double[numScoreSlots];
        // NOTE(review): bit 0 of backpropFlags appears to request GRAVE statistics — confirm in MCTS.
        if ((mcts.backpropFlags() & 1) != 0) {
            this.graveStats = new ConcurrentHashMap<>();
        } else {
            this.graveStats = null;
        }
    }

    /** Registers {@code child} at index {@code i}; index semantics are defined by the subclass. */
    public abstract void addChild(BaseNode child, int i);

    /** Returns the child reached by the n-th legal move, or {@code null} if that child does not exist. */
    public abstract BaseNode childForNthLegalMove(int n);

    /** Returns this node's context reference. */
    public abstract Context contextRef();

    /** Returns this node's deterministic context reference. */
    public abstract Context deterministicContextRef();

    /** Returns the child corresponding to {@code move}, as determined by the subclass. */
    public abstract BaseNode findChildForMove(Move move);

    /** Returns the learned selection policy distribution over this node's legal moves. */
    public abstract FVector learnedSelectionPolicy();

    /** Returns the moves available from this node. */
    public abstract FastArrayList<Move> movesFromNode();

    /** Returns this node's colour (presumably the player/agent to move — confirm in subclasses). */
    public abstract int nodeColour();

    /** Returns the n-th legal move from this node. */
    public abstract Move nthLegalMove(int n);

    /** Returns the number of legal moves from this node. */
    public abstract int numLegalMoves();

    /** Returns the context to use for playouts from this node. */
    public abstract Context playoutContext();

    /** Initialises this node as the search root for {@code context}. */
    public abstract void rootInit(Context context);

    /** Hook called at the start of a new search iteration with {@code context}. */
    public abstract void startNewIteration(Context context);

    /** Returns the summed visit counts of this node's legal children. */
    public abstract int sumLegalChildVisits();

    /** Traverses from this node via index {@code i} and returns the resulting context. */
    public abstract Context traverse(int i);

    /** Refreshes this node's context reference. */
    public abstract void updateContextRef();

    /** Releases any thread-local state held by this node. */
    public abstract void cleanThreadLocals();

    /**
     * Returns the mean score for {@code agent}, treating every in-flight virtual visit as a loss of -1.
     *
     * @param agent agent index into the score arrays
     * @return 0 if the node has never been visited, otherwise the virtual-loss-adjusted mean
     */
    public double expectedScore(final int agent) {
        if (this.numVisits == 0) {
            return 0.0d;
        }
        // Read once so numerator and denominator agree even if other threads add/resolve visits.
        final int virtualVisits = this.numVirtualVisits.get();
        return (this.totalScores[agent] - virtualVisits) / (this.numVisits + virtualVisits);
    }

    /** Score used for exploitation; the base implementation is {@link #expectedScore(int)}. */
    public double exploitationScore(final int agent) {
        return expectedScore(agent);
    }

    /** Base implementation never reports a proven value; subclasses may override. */
    public boolean isValueProven(final int agent) {
        return false;
    }

    /** Returns the heuristic value estimates, or {@code null} if none were set. */
    public double[] heuristicValueEstimates() {
        return this.heuristicValueEstimates;
    }

    /** Returns the visit count. */
    public int numVisits() {
        return this.numVisits;
    }

    /** Returns the current number of in-flight virtual visits. */
    public int numVirtualVisits() {
        return this.numVirtualVisits.get();
    }

    /** Atomically records one more in-flight virtual visit. */
    public void addVirtualVisit() {
        this.numVirtualVisits.incrementAndGet();
    }

    /** Returns the parent node. */
    public BaseNode parent() {
        return this.parent;
    }

    /** Returns the move leading to this node. */
    public Move parentMove() {
        return this.parentMove;
    }

    /** Overwrites the visit count. */
    public void setNumVisits(final int numVisits) {
        this.numVisits = numVisits;
    }

    /** Re-links this node under {@code newParent}. */
    public void setParent(final BaseNode newParent) {
        this.parent = newParent;
    }

    /** Stores the given heuristic value estimates (the array is kept by reference). */
    public void setHeuristicValueEstimates(final double[] estimates) {
        this.heuristicValueEstimates = estimates;
    }

    /** Returns the raw sum of backpropagated utilities for {@code agent} (no virtual-loss adjustment). */
    public double totalScore(final int agent) {
        return this.totalScores[agent];
    }

    /**
     * Returns the sum of squared utilities for {@code agent}, adding 1 per virtual visit
     * (the square of the -1 virtual loss used by {@link #expectedScore(int)}).
     */
    public double sumSquaredScores(final int agent) {
        return this.sumSquaredScores[agent] + this.numVirtualVisits.get();
    }

    /**
     * Backpropagates one result through this node.
     *
     * @param utilities utilities indexed by agent; indices 1 .. totalScores.length-1 are read
     */
    public void update(final double[] utilities) {
        ++this.numVisits;
        for (int agent = 1; agent < this.totalScores.length; ++agent) {
            final double utility = utilities[agent];
            this.totalScores[agent] += utility;
            this.sumSquaredScores[agent] += utility * utility;
        }
        // Resolves one virtual visit; presumably paired with an earlier addVirtualVisit() — confirm.
        this.numVirtualVisits.decrementAndGet();
    }

    /**
     * Returns the value assumed for unvisited children, according to the search's Q-init strategy.
     *
     * @param agent agent whose perspective is used by the PARENT strategy
     * @return 0 (DRAW/default), 10000 (INF, or PARENT with no visits), -1 (LOSS), 1 (WIN),
     *     or this node's {@link #expectedScore(int)} (PARENT with visits)
     */
    public double valueEstimateUnvisitedChildren(final int agent) {
        switch (this.mcts.qInit()) {
            case DRAW:
                return 0.0d;
            case INF:
                return 10000.0d;
            case LOSS:
                return -1.0d;
            case PARENT:
                return (this.numVisits == 0) ? 10000.0d : expectedScore(agent);
            case WIN:
                return 1.0d;
            default:
                return 0.0d;
        }
    }

    /**
     * Returns the GRAVE statistics entry for {@code moveKey}, creating it if absent.
     *
     * <p>Uses {@code computeIfAbsent} so that, on the {@link ConcurrentHashMap} created by the
     * constructor, concurrent callers always share a single entry (a separate get/put could let
     * two threads create different entries and lose the updates made to one of them).
     * Requires GRAVE statistics to be enabled; otherwise {@link #graveStats} is {@code null}.
     */
    public NodeStatistics getOrCreateGraveStatsEntry(final MCTS.MoveKey moveKey) {
        return this.graveStats.computeIfAbsent(moveKey, key -> new NodeStatistics());
    }

    /** Returns the GRAVE statistics for {@code moveKey}, or {@code null} if there is no entry. */
    public NodeStatistics graveStats(final MCTS.MoveKey moveKey) {
        return this.graveStats.get(moveKey);
    }

    /**
     * Builds a distribution over legal moves from child visit counts.
     *
     * @param tau temperature; 0 gives a uniform distribution over the most-visited move(s),
     *     otherwise visit counts are raised to the power {@code 1/tau} and normalised
     * @return a vector of size {@link #numLegalMoves()} (left all-zero if every count is 0)
     */
    public FVector computeVisitCountPolicy(final double tau) {
        final int numMoves = numLegalMoves();
        final FVector policy = new FVector(numMoves);

        if (tau == 0.0d) {
            int maxVisits = -1;
            final TIntArrayList maxIndices = new TIntArrayList();
            for (int n = 0; n < numMoves; ++n) {
                final int visits = visitsOfNthChild(n);
                if (visits > maxVisits) {
                    maxVisits = visits;
                    maxIndices.reset();
                    maxIndices.add(n);
                } else if (visits == maxVisits) {
                    maxIndices.add(n);
                }
            }
            // Ties share the probability mass equally.
            final float share = 1.0f / maxIndices.size();
            for (int k = 0; k < maxIndices.size(); ++k) {
                policy.set(maxIndices.getQuick(k), share);
            }
        } else {
            for (int n = 0; n < numMoves; ++n) {
                policy.set(n, visitsOfNthChild(n));
            }
            if (tau != 1.0d) {
                policy.raiseToPower(1.0d / tau);
            }
            final float total = policy.sum();
            if (total > 0.0f) {
                policy.mult(1.0f / total);
            }
        }
        return policy;
    }

    /** Returns the visit count of the n-th legal child, or 0 if that child does not exist. */
    private int visitsOfNthChild(final int n) {
        final BaseNode child = childForNthLegalMove(n);
        return (child == null) ? 0 : child.numVisits;
    }

    /** Entropy of the (temperature 1) visit-count policy, normalised by log(number of moves). */
    public double normalisedEntropy() {
        return normalisedEntropyOf(computeVisitCountPolicy(1.0d));
    }

    /** Entropy of {@link #learnedSelectionPolicy()}, normalised by log(number of moves). */
    public double learnedSelectionPolicyNormalisedEntropy() {
        return normalisedEntropyOf(learnedSelectionPolicy());
    }

    /**
     * Entropy of the playout policy's distribution over the legal moves in {@link #contextRef()},
     * normalised by log(number of moves). Assumes the search's playout strategy is a
     * {@link SoftmaxPolicyLinear}; any other strategy causes a {@link ClassCastException}.
     */
    public double learnedPlayoutPolicyNormalisedEntropy() {
        final Context context = contextRef();
        final FVector distribution = ((SoftmaxPolicyLinear) this.mcts.playoutStrategy())
                .computeDistribution(context, context.game().moves(context).moves(), true);
        return normalisedEntropyOf(distribution);
    }

    /**
     * Shannon entropy (natural log) of {@code distribution} divided by {@code log(dim)}.
     * Zero entries are skipped (0·log 0 = 0); returns 0 for dimension 0 or 1.
     */
    private static double normalisedEntropyOf(final FVector distribution) {
        final int dim = distribution.dim();
        if (dim <= 1) {
            return 0.0d;
        }
        double entropy = 0.0d;
        for (int i = 0; i < dim; ++i) {
            final float p = distribution.get(i);
            if (p > 0.0f) {
                entropy -= p * Math.log(p);
            }
        }
        return entropy / Math.log(dim);
    }

    /**
     * Builds an Expert Iteration training sample from this node.
     *
     * <p>Pruned {@code ScoreBoundsNode} children have their policy entry lowered to the policy's
     * minimum before renormalising. If every child is pruned, the unmodified visit-count policy is
     * used instead.
     *
     * @param weightVisitCount weight passed through to the {@link ExItExperience}
     * @return the sample (copies of legal moves, visit-count policy, per-move value estimates)
     */
    public ExItExperience generateExItExperience(final float weightVisitCount) {
        final int numMoves = numLegalMoves();
        final FastArrayList<Move> actions = new FastArrayList<>(numMoves);
        final float[] valueEstimates = new float[numMoves];
        final State state = deterministicContextRef().state();
        // Every child's value is read from the perspective of the agent to move at this node.
        final int moverAgent = state.playerToAgent(state.mover());

        for (int n = 0; n < numMoves; ++n) {
            final BaseNode child = childForNthLegalMove(n);
            final Move legalMove = nthLegalMove(n);
            final Move action = new Move(legalMove);
            action.setMover(legalMove.mover());
            // The stored copy is kept without consequents.
            action.then().clear();
            actions.add(action);
            // Children that were never created get a value estimate of -1.
            valueEstimates[n] = (child == null) ? -1.0f : (float) child.expectedScore(moverAgent);
        }

        FVector visitCountPolicy = computeVisitCountPolicy(1.0d);
        final float minProb = visitCountPolicy.min();
        boolean allPruned = true;
        for (int n = 0; n < numMoves; ++n) {
            final BaseNode child = childForNthLegalMove(n);
            if (child instanceof ScoreBoundsNode && ((ScoreBoundsNode) child).isPruned()) {
                visitCountPolicy.set(n, minProb);
            } else {
                allPruned = false;
            }
        }

        if (allPruned) {
            // Everything pruned: keep the unmodified visit-count policy.
            visitCountPolicy = computeVisitCountPolicy(1.0d);
        } else {
            visitCountPolicy.normalise();
        }

        return new ExItExperience(
                new Context(deterministicContextRef()),
                new ExItExperience.ExItExperienceState(deterministicContextRef()),
                actions,
                visitCountPolicy,
                FVector.wrap(valueEstimates),
                weightVisitCount);
    }

    /** Returns a mutable list holding one sample from {@link #generateExItExperience(float)} with weight 1. */
    public List<ExItExperience> generateExItExperiences() {
        final List<ExItExperience> experiences = new ArrayList<>();
        experiences.add(generateExItExperience(1.0f));
        return experiences;
    }

    /** Returns this node's lock. */
    public ReentrantLock getLock() {
        return this.nodeLock;
    }
}
