package search.mcts.nodes;

import expert_iteration.ExItExperience;
import game.Game;
import gnu.trove.list.array.TIntArrayList;
import java.util.HashMap;
import java.util.Map;
import main.collections.FVector;
import main.collections.FastArrayList;
import policies.softmax.SoftmaxPolicy;
import search.mcts.MCTS;
import util.Context;
import util.Move;
import util.state.State;

/* loaded from: input_file:search/mcts/nodes/BaseNode.class */
/**
 * Abstract base class for nodes in an MCTS search tree.
 *
 * <p>Holds the data shared by all node implementations: the parent link, the move that led to
 * this node (with and without consequents), the visit count, the per-index sums of
 * backpropagated scores, and - only when bit {@code 1} of {@code mcts.backpropFlags()} is set -
 * a table of per-move statistics keyed by {@link MoveKey}. Subclasses supply child storage,
 * context handling and legal-move enumeration through the abstract methods.
 */
public abstract class BaseNode {
    /** Parent node (re-assignable via {@link #setParent(BaseNode)}). */
    protected BaseNode parent;

    /** Move played in the parent to reach this node. */
    protected final Move parentMove;

    /** Same as {@link #parentMove}, but without consequents (per the field name - confirm). */
    protected final Move parentMoveWithoutConseq;

    /** The search that owns this node. */
    protected final MCTS mcts;

    /** Number of times {@link #update(double[])} has been called on this node. */
    protected int numVisits = 0;

    /** Accumulated scores; length is players + 1 and index 0 is never updated. */
    protected final double[] totalScores;

    /** Per-move statistics; {@code null} unless bit 1 of {@code mcts.backpropFlags()} is set. */
    protected final Map<MoveKey, NodeStatistics> graveStats;

    /**
     * Hashable wrapper around a {@link Move}, used as the key of the per-move statistics table.
     *
     * <p>Ordinary moves are compared by mover, non-decision state, and non-decision to/from
     * positions; for moves that are not oriented, to/from may also appear swapped. Pass and swap
     * moves are hashed by the {@code depth} given at construction, so they only equal keys of the
     * same kind that were created with the same depth (this keeps {@code equals} consistent with
     * {@code hashCode}).
     */
    public static class MoveKey {
        /** The wrapped move. */
        public final Move move;

        /** Discriminator mixed into pass/swap hashes; presumably the search depth - confirm with callers. */
        private final int depth;

        /** Hash code computed once at construction time. */
        private final int cachedHashCode;

        /**
         * Creates a key for the given move.
         *
         * @param move  the move to wrap; must not be {@code null} (it is dereferenced immediately)
         * @param depth discriminator that only affects pass and swap moves
         */
        public MoveKey(final Move move, final int depth) {
            this.move = move;
            this.depth = depth;

            final int prime = 31;
            int result = 1;
            if (move.isPass()) {
                // Distinct constant offsets separate pass hashes from swap hashes at the same depth.
                result = prime * result + depth + 1297;
            } else if (move.isSwap()) {
                result = prime * result + depth + 587;
            } else {
                if (move.isOrientedMove()) {
                    result = prime * result + move.toNonDecision();
                    result = prime * result + move.fromNonDecision();
                } else {
                    // Order-insensitive sum, so reversed non-oriented moves hash alike (see equals).
                    result = prime * result + move.toNonDecision() + move.fromNonDecision();
                }
                result = prime * result + move.stateNonDecision();
            }
            result = prime * result + move.mover();
            this.cachedHashCode = result;
        }

        @Override
        public int hashCode() {
            return cachedHashCode;
        }

        @Override
        public boolean equals(final Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof MoveKey)) {
                return false;
            }
            final MoveKey other = (MoveKey) obj;
            if (move == null) {
                // Defensive only: the constructor already dereferences the move.
                return other.move == null;
            }
            final Move otherMove = other.move;

            // Pass/swap moves are hashed by depth, not by position, so they must also be
            // told apart by kind and depth here; otherwise equal keys could hash differently.
            if (move.isPass() != otherMove.isPass() || move.isSwap() != otherMove.isSwap()) {
                return false;
            }
            if ((move.isPass() || move.isSwap()) && depth != other.depth) {
                return false;
            }

            if (move.isOrientedMove() != otherMove.isOrientedMove()) {
                return false;
            }
            if (move.isOrientedMove()) {
                if (move.toNonDecision() != otherMove.toNonDecision()
                        || move.fromNonDecision() != otherMove.fromNonDecision()) {
                    return false;
                }
            } else {
                final boolean sameEnds = move.toNonDecision() == otherMove.toNonDecision()
                        && move.fromNonDecision() == otherMove.fromNonDecision();
                final boolean swappedEnds = move.toNonDecision() == otherMove.fromNonDecision()
                        && move.fromNonDecision() == otherMove.toNonDecision();
                if (!sameEnds && !swappedEnds) {
                    return false;
                }
            }
            return move.mover() == otherMove.mover()
                    && move.stateNonDecision() == otherMove.stateNonDecision();
        }

        @Override
        public String toString() {
            return "[Move = " + this.move + ", Hash = " + this.cachedHashCode + "]";
        }
    }

    /** Mutable visit count / score accumulator stored per {@link MoveKey}. */
    public static class NodeStatistics {
        /** Number of recorded visits. */
        public int visitCount = 0;

        /** Sum of recorded scores. */
        public double accumulatedScore = 0.0d;

        @Override
        public String toString() {
            return "[visits = " + this.visitCount + ", accum. score = " + this.accumulatedScore + "]";
        }
    }

    /**
     * Creates a node.
     *
     * @param mcts                    the owning search
     * @param parent                  the parent node
     * @param parentMove              move leading from the parent to this node
     * @param parentMoveWithoutConseq the same move without consequents
     * @param game                    the game, used to size {@link #totalScores}
     */
    public BaseNode(
            final MCTS mcts,
            final BaseNode parent,
            final Move parentMove,
            final Move parentMoveWithoutConseq,
            final Game game) {
        this.mcts = mcts;
        this.parent = parent;
        this.parentMove = parentMove;
        this.parentMoveWithoutConseq = parentMoveWithoutConseq;
        this.totalScores = new double[game.players().count() + 1];
        // NOTE(review): bit 1 presumably flags GRAVE-style statistics in MCTS - confirm constant.
        if ((mcts.backpropFlags() & 1) != 0) {
            this.graveStats = new HashMap<>();
        } else {
            this.graveStats = null;
        }
    }

    public abstract void addChild(BaseNode baseNode, int i);

    public abstract BaseNode childForNthLegalMove(int i);

    public abstract Context contextRef();

    public abstract Context deterministicContextRef();

    public abstract BaseNode findChildForMove(Move move);

    public abstract FVector learnedSelectionPolicy();

    public abstract FastArrayList<Move> movesFromNode();

    public abstract int nodeColour();

    public abstract Move nthLegalMove(int i);

    public abstract int numLegalMoves();

    public abstract Context playoutContext();

    public abstract void rootInit(Context context);

    public abstract void startNewIteration(Context context);

    public abstract int sumLegalChildVisits();

    public abstract Context traverse(int i);

    public abstract void updateContextRef();

    /**
     * Returns the mean accumulated score for a player.
     *
     * @param player player index, mapped to a score index via {@code state.playerToAgent}
     * @param state  state providing the player-to-agent mapping
     * @return total score divided by visits, or 0 if this node was never visited
     */
    public double averageScore(final int player, final State state) {
        if (numVisits == 0) {
            return 0.0d;
        }
        return totalScores[state.playerToAgent(player)] / numVisits;
    }

    /** @return number of visits recorded on this node */
    public int numVisits() {
        return numVisits;
    }

    /** @return the parent node */
    public BaseNode parent() {
        return parent;
    }

    /** @return the move that led from the parent to this node */
    public Move parentMove() {
        return parentMove;
    }

    /** Overwrites the visit count. */
    public void setNumVisits(final int numVisits) {
        this.numVisits = numVisits;
    }

    /** Re-links this node to a new parent. */
    public void setParent(final BaseNode newParent) {
        this.parent = newParent;
    }

    /** @return the raw accumulated score at the given index of {@link #totalScores} */
    public double totalScore(final int index) {
        return totalScores[index];
    }

    /**
     * Records one visit and adds the given scores to the running totals.
     *
     * @param utilities scores to add; entries 1 .. players are read (index 0 is skipped), so the
     *                  array must be at least as long as {@link #totalScores}
     */
    public void update(final double[] utilities) {
        ++numVisits;
        for (int index = 1; index < totalScores.length; ++index) {
            totalScores[index] += utilities[index];
        }
    }

    /**
     * Returns the value to assume for children that have not been visited yet, as selected by
     * {@code mcts.qInit()}.
     *
     * @param player player index used for the PARENT strategy
     * @param state  state used for the PARENT strategy
     * @return 0 (DRAW / default), 10000 (INF, or PARENT with no visits), -1 (LOSS), 1 (WIN), or
     *         this node's {@link #averageScore(int, State)} (PARENT)
     */
    public double valueEstimateUnvisitedChildren(final int player, final State state) {
        switch (mcts.qInit()) {
            case DRAW:
                return 0.0d;
            case INF:
                return 10000.0d;
            case LOSS:
                return -1.0d;
            case PARENT:
                if (numVisits == 0) {
                    return 10000.0d;
                }
                return averageScore(player, state);
            case WIN:
                return 1.0d;
            default:
                return 0.0d;
        }
    }

    /**
     * Returns the statistics entry for a move, inserting a fresh one if absent.
     * Throws {@link NullPointerException} if the statistics table was not allocated.
     */
    public NodeStatistics getOrCreateGraveStatsEntry(final MoveKey moveKey) {
        NodeStatistics stats = graveStats.get(moveKey);
        if (stats == null) {
            stats = new NodeStatistics();
            graveStats.put(moveKey, stats);
        }
        return stats;
    }

    /** @return the statistics entry for a move, or {@code null} if none exists */
    public NodeStatistics graveStats(final MoveKey moveKey) {
        return graveStats.get(moveKey);
    }

    /**
     * Builds a distribution over this node's legal moves from child visit counts.
     *
     * @param tau temperature; 0 yields a uniform distribution over the most-visited moves
     *            (unexpanded children count as 0 visits), otherwise counts are raised to
     *            {@code 1 / tau} and normalised when their sum is positive
     * @return vector of length {@link #numLegalMoves()}
     */
    public FVector computeVisitCountPolicy(final double tau) {
        final int numMoves = numLegalMoves();
        final FVector policy = new FVector(numMoves);

        if (tau == 0.0d) {
            int maxVisits = -1;
            final TIntArrayList maxIndices = new TIntArrayList();
            for (int i = 0; i < numMoves; ++i) {
                final BaseNode child = childForNthLegalMove(i);
                final int visits = (child == null) ? 0 : child.numVisits;
                if (visits > maxVisits) {
                    maxVisits = visits;
                    maxIndices.reset();
                    maxIndices.add(i);
                } else if (visits == maxVisits) {
                    maxIndices.add(i);
                }
            }
            final float prob = 1.0f / maxIndices.size();
            for (int j = 0; j < maxIndices.size(); ++j) {
                policy.set(maxIndices.getQuick(j), prob);
            }
        } else {
            for (int i = 0; i < numMoves; ++i) {
                final BaseNode child = childForNthLegalMove(i);
                policy.set(i, (child == null) ? 0 : child.numVisits);
            }
            if (tau != 1.0d) {
                policy.raiseToPower(1.0d / tau);
            }
            final float sum = policy.sum();
            if (sum > 0.0f) {
                policy.mult(1.0f / sum);
            }
        }
        return policy;
    }

    /** @return entropy of the visit-count policy (tau = 1), normalised to [0, 1] by log(#moves) */
    public double normalisedEntropy() {
        return normalisedEntropyOf(computeVisitCountPolicy(1.0d));
    }

    /** @return normalised entropy of {@link #learnedSelectionPolicy()} */
    public double learnedSelectionPolicyNormalisedEntropy() {
        return normalisedEntropyOf(learnedSelectionPolicy());
    }

    /**
     * Returns the normalised entropy of the playout policy's distribution over the legal moves
     * in {@link #contextRef()}. Assumes {@code mcts.playoutStrategy()} is a {@link SoftmaxPolicy};
     * otherwise a {@link ClassCastException} is thrown.
     */
    public double learnedPlayoutPolicyNormalisedEntropy() {
        final Context context = contextRef();
        final FVector distribution = ((SoftmaxPolicy) mcts.playoutStrategy())
                .computeDistribution(context, context.game().moves(context).moves(), true);
        return normalisedEntropyOf(distribution);
    }

    /**
     * Computes -sum(p * ln p) over positive entries, divided by ln(dim).
     *
     * @return 0 when the vector has at most one entry
     */
    private static double normalisedEntropyOf(final FVector distribution) {
        final int dim = distribution.dim();
        if (dim <= 1) {
            return 0.0d;
        }
        double entropy = 0.0d;
        for (int i = 0; i < dim; ++i) {
            final float p = distribution.get(i);
            if (p > 0.0f) {
                entropy -= p * Math.log(p);
            }
        }
        return entropy / Math.log(dim);
    }

    /**
     * Packages this node into an {@link ExItExperience}: copies of the legal moves with their
     * consequents cleared, the visit-count policy at tau = 1, and per-move value estimates
     * (child average score for the current mover, or -1 for unexpanded children).
     */
    public ExItExperience generateExItExperience() {
        final int numMoves = numLegalMoves();
        final FastArrayList<Move> actions = new FastArrayList<>(numMoves);
        final float[] valueEstimates = new float[numMoves];
        final State state = deterministicContextRef().state();

        for (int i = 0; i < numMoves; ++i) {
            final BaseNode child = childForNthLegalMove(i);
            final Move copy = new Move(nthLegalMove(i));
            // NOTE(review): assumes the copy owns its then() list; otherwise this also clears
            // the node's own legal move - confirm against Move's copy constructor.
            copy.then().clear();
            actions.add(copy);
            valueEstimates[i] = (child == null)
                    ? -1.0f
                    : (float) child.averageScore(state.mover(), state);
        }
        return new ExItExperience(
                new ExItExperience.ExItExperienceState(deterministicContextRef()),
                actions,
                computeVisitCountPolicy(1.0d),
                FVector.wrap(valueEstimates));
    }
}
