package search.mcts.nodes;

import game.Game;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import main.collections.FVector;
import main.collections.FastArrayList;
import other.context.Context;
import other.move.Move;
import policies.softmax.SoftmaxPolicy;
import search.mcts.MCTS;

/* loaded from: input_file:search/mcts/nodes/OpenLoopNode.class */
/**
 * Search node for open-loop MCTS.
 *
 * <p>An open-loop node does not own a fixed game state. Each search thread keeps its own
 * context for the current iteration in {@link #currentItContext}. When a non-root node is
 * reached ({@link #updateContextRef()}), the legal moves, the move-index to child mapping and
 * (if the search has a learned selection policy) the selection distribution are recomputed
 * from that thread's context and stored in thread-local slots.
 *
 * <p>A root node is handled differently. After {@link #rootInit(Context)} supplies a
 * deterministic context, the same data is computed from that context and stored in the plain
 * {@code root*} fields, which take precedence over the thread-local slots. It is recomputed
 * whenever a child is added to the root. All recomputation happens while holding
 * {@code getLock()}.
 */
public final class OpenLoopNode extends BaseNode {

    /** Children added through {@link #addChild(BaseNode, int)}. */
    protected final List<OpenLoopNode> children = new ArrayList<>(10);

    /** Per-thread context of the iteration currently passing through this node. */
    protected ThreadLocal<Context> currentItContext = new ThreadLocal<>();

    /** Context given to {@link #rootInit(Context)}; {@code null} until then. */
    protected Context deterministicContext = null;

    /** Legal moves computed from {@link #deterministicContext}; non-null only for an initialised root. */
    protected FastArrayList<Move> rootLegalMovesList = null;

    /** Per-thread legal moves computed from {@link #currentItContext} (non-root use). */
    protected ThreadLocal<FastArrayList<Move>> currentLegalMoves = new ThreadLocal<>();

    /** Per-thread learned selection distribution over {@link #currentLegalMoves}. */
    protected ThreadLocal<FVector> learnedSelectionPolicy = new ThreadLocal<>();

    /** Learned selection distribution over {@link #rootLegalMovesList} (root only). */
    protected FVector rootLearnedSelectionPolicy = null;

    /** Per-thread mapping from legal-move index to child ({@code null} entries = not expanded). */
    protected ThreadLocal<OpenLoopNode[]> moveIdxToNode = new ThreadLocal<>();

    /** Mapping from root legal-move index to child (root only). */
    protected OpenLoopNode[] rootMoveIdxToNode = null;

    /**
     * Per-thread cached logit of the move leading to this node, as computed by the parent's
     * learned selection policy. NaN means "not computed yet" for the calling thread.
     */
    protected ThreadLocal<Float> logit = ThreadLocal.withInitial(() -> Float.valueOf(Float.NaN));

    /**
     * Creates a node; all arguments are forwarded unchanged to the {@link BaseNode} constructor.
     *
     * @param mcts the search that owns this node
     * @param parent parent node ({@code null} when this node is a root)
     * @param parentMove move leading from the parent to this node
     * @param parentMoveWithoutConseq the same move, presumably without its consequents; this is
     *     the value compared against freshly generated legal moves
     * @param game the game being searched
     */
    public OpenLoopNode(MCTS mcts, BaseNode parent, Move parentMove, Move parentMoveWithoutConseq, Game game) {
        super(mcts, parent, parentMove, parentMoveWithoutConseq, game);
    }

    /**
     * Appends a child. If this node is an already initialised root, the root legal-move
     * dependencies are recomputed so the new child appears in the move-index mapping.
     *
     * @param child the child to add; must be an {@code OpenLoopNode}
     * @param moveIdx not used by this implementation
     */
    @Override
    public void addChild(BaseNode child, int moveIdx) {
        this.children.add((OpenLoopNode) child);
        if (parent() == null && this.deterministicContext != null) {
            updateLegalMoveDependencies(true);
        }
    }

    /**
     * Returns the child reached by the {@code n}-th legal move, or {@code null} if that move
     * has not been expanded. Uses the root mapping when present, otherwise the calling thread's.
     */
    @Override
    public OpenLoopNode childForNthLegalMove(int n) {
        if (this.rootMoveIdxToNode != null) {
            return this.rootMoveIdxToNode[n];
        }
        return this.moveIdxToNode.get()[n];
    }

    /** Returns the calling thread's current-iteration context (not a copy). */
    @Override
    public Context contextRef() {
        return this.currentItContext.get();
    }

    /** Returns the context stored by {@link #rootInit(Context)}, or {@code null}. */
    @Override
    public Context deterministicContextRef() {
        return this.deterministicContext;
    }

    /**
     * Returns the first child whose {@code parentMove()} equals {@code move}, or {@code null}
     * if there is none.
     */
    @Override
    public OpenLoopNode findChildForMove(Move move) {
        for (OpenLoopNode child : this.children) {
            if (child.parentMove().equals(move)) {
                return child;
            }
        }
        return null;
    }

    /** Returns the root distribution when present, otherwise the calling thread's (may be null). */
    @Override
    public FVector learnedSelectionPolicy() {
        if (this.rootLearnedSelectionPolicy != null) {
            return this.rootLearnedSelectionPolicy;
        }
        return this.learnedSelectionPolicy.get();
    }

    /** Returns the root legal moves when present, otherwise the calling thread's (may be null). */
    @Override
    public FastArrayList<Move> movesFromNode() {
        if (this.rootLegalMovesList != null) {
            return this.rootLegalMovesList;
        }
        return this.currentLegalMoves.get();
    }

    /** Always returns 0. */
    @Override
    public int nodeColour() {
        return 0;
    }

    /** Returns the {@code n}-th entry of {@link #movesFromNode()}. */
    @Override
    public Move nthLegalMove(int n) {
        return movesFromNode().get(n);
    }

    /** Returns the size of {@link #movesFromNode()}. */
    @Override
    public int numLegalMoves() {
        return movesFromNode().size();
    }

    /** Returns the calling thread's current-iteration context (the same object as {@link #contextRef()}). */
    @Override
    public Context playoutContext() {
        return this.currentItContext.get();
    }

    /**
     * Makes this node a root for {@code context}. The context itself (not a copy) is kept as
     * the deterministic context. The calling thread gets a copy as its iteration context, and
     * the root legal-move dependencies are computed.
     */
    @Override
    public void rootInit(Context context) {
        this.deterministicContext = context;
        this.currentItContext.set(this.mcts.copyContext(context));
        updateLegalMoveDependencies(true);
    }

    /** Gives the calling thread a fresh copy of {@code context} as its iteration context. */
    @Override
    public void startNewIteration(Context context) {
        this.currentItContext.set(this.mcts.copyContext(context));
    }

    /** Sums the visit counts of the children mapped to the currently legal moves. */
    @Override
    public int sumLegalChildVisits() {
        int total = 0;
        for (int n = 0; n < numLegalMoves(); ++n) {
            final OpenLoopNode child = childForNthLegalMove(n);
            if (child != null) {
                total += child.numVisits;
            }
        }
        return total;
    }

    /**
     * Applies the {@code n}-th legal move to the calling thread's iteration context. The
     * context is mutated in place and then returned.
     */
    @Override
    public Context traverse(int n) {
        final Context context = this.currentItContext.get();
        context.game().apply(context, movesFromNode().get(n));
        return context;
    }

    /**
     * For a non-root node, adopts the context reference returned by the parent's
     * {@code contextRef()} (no copy is made) and recomputes this thread's legal-move
     * dependencies from it. Does nothing on a root.
     */
    @Override
    public void updateContextRef() {
        if (this.parent != null) {
            this.currentItContext.set(this.parent.contextRef());
            updateLegalMoveDependencies(false);
        }
    }

    /**
     * Clears the calling thread's values of every thread-local in this node. It then recurses
     * into all children while holding this node's lock. Only the calling thread's values are
     * removed.
     */
    @Override
    public void cleanThreadLocals() {
        this.currentItContext.remove();
        this.currentLegalMoves.remove();
        this.learnedSelectionPolicy.remove();
        this.moveIdxToNode.remove();
        this.logit.remove();
        getLock().lock();
        try {
            for (OpenLoopNode child : this.children) {
                child.cleanThreadLocals();
            }
        } finally {
            getLock().unlock();
        }
    }

    /**
     * Recomputes legal moves, the move-index to child mapping and (if configured) the learned
     * selection distribution while holding this node's lock.
     *
     * @param root {@code true} to compute from {@link #deterministicContext} into the root
     *     fields (also pruning children whose move is no longer legal); {@code false} to compute
     *     from the calling thread's iteration context into the thread-local slots
     */
    private void updateLegalMoveDependencies(boolean root) {
        getLock().lock();
        try {
            final Context context = root ? this.deterministicContext : this.currentItContext.get();
            final FastArrayList<Move> legalMoves = storeLegalMoves(context, root);
            if (root) {
                // A new root may have children whose moves are no longer legal.
                removeChildrenNotIn(legalMoves);
            }
            final OpenLoopNode[] moveIdxToChild = storeMoveIdxToChild(legalMoves, root);
            if (this.mcts.learnedSelectionPolicy() != null) {
                storeLearnedDistribution(context, legalMoves, moveIdxToChild, root);
            }
        } finally {
            getLock().unlock();
        }
    }

    /** Generates the legal moves of {@code context} and stores them in the root or thread-local slot. */
    private FastArrayList<Move> storeLegalMoves(Context context, boolean root) {
        final FastArrayList<Move> legalMoves = new FastArrayList<>(context.game().moves(context).moves());
        if (root) {
            this.rootLegalMovesList = legalMoves;
            this.currentLegalMoves.set(null);
        } else {
            this.currentLegalMoves.set(legalMoves);
        }
        return legalMoves;
    }

    /** Removes, back to front, every child whose move-without-consequents is not in {@code legalMoves}. */
    private void removeChildrenNotIn(FastArrayList<Move> legalMoves) {
        for (int i = this.children.size() - 1; i >= 0; --i) {
            if (!legalMoves.contains(this.children.get(i).parentMoveWithoutConseq)) {
                this.children.remove(i).cleanThreadLocals();
            }
        }
    }

    /**
     * Builds and stores the array mapping each legal-move index to the first matching child
     * ({@code null} if none).
     */
    private OpenLoopNode[] storeMoveIdxToChild(FastArrayList<Move> legalMoves, boolean root) {
        final OpenLoopNode[] mapping = new OpenLoopNode[legalMoves.size()];
        if (root) {
            this.rootMoveIdxToNode = mapping;
            this.moveIdxToNode.set(null);
        } else {
            this.moveIdxToNode.set(mapping);
        }
        for (int moveIdx = 0; moveIdx < mapping.length; ++moveIdx) {
            final Move move = legalMoves.get(moveIdx);
            for (int childIdx = 0; childIdx < this.children.size(); ++childIdx) {
                final OpenLoopNode child = this.children.get(childIdx);
                if (move.equals(child.parentMoveWithoutConseq)) {
                    mapping[moveIdx] = child;
                    break;
                }
            }
        }
        return mapping;
    }

    /**
     * Computes logits for all legal moves and turns them into a distribution. Each child's
     * thread-local cached logit is reused when available and filled in otherwise. The
     * distribution is softmax for a {@code SoftmaxPolicy}; otherwise it is normalised. It is
     * stored in the root or thread-local slot.
     */
    private void storeLearnedDistribution(
            Context context, FastArrayList<Move> legalMoves, OpenLoopNode[] moveIdxToChild, boolean root) {
        final float[] logits = new float[moveIdxToChild.length];
        for (int i = 0; i < logits.length; ++i) {
            final OpenLoopNode child = moveIdxToChild[i];
            final float cached = child == null ? Float.NaN : child.logit.get().floatValue();
            if (Float.isNaN(cached)) {
                logits[i] = this.mcts.learnedSelectionPolicy().computeLogit(context, legalMoves.get(i));
                if (child != null) {
                    child.logit.set(Float.valueOf(logits[i]));
                }
            } else {
                logits[i] = cached;
            }
        }
        final FVector distribution = FVector.wrap(logits);
        if (this.mcts.learnedSelectionPolicy() instanceof SoftmaxPolicy) {
            distribution.softmax();
        } else {
            distribution.normalise();
        }
        if (root) {
            this.rootLearnedSelectionPolicy = distribution;
            this.learnedSelectionPolicy.set(null);
        } else {
            this.learnedSelectionPolicy.set(distribution);
        }
    }
}
