package training.policy_gradients;

import features.FeatureVector;
import features.feature_sets.BaseFeatureSet;
import features.feature_sets.network.JITSPatterNetFeatureSet;
import features.spatial.FeatureUtils;
import game.Game;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TLongArrayList;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import main.DaemonThreadFactory;
import main.collections.FVector;
import main.collections.FastArrayList;
import main.collections.ListUtils;
import optimisers.Optimiser;
import other.RankUtils;
import other.context.Context;
import other.move.Move;
import other.state.State;
import other.trial.Trial;
import policies.softmax.SoftmaxPolicyLinear;
import training.ExperienceSample;
import training.expert_iteration.params.FeatureDiscoveryParams;
import training.expert_iteration.params.ObjectiveParams;
import training.expert_iteration.params.TrainingParams;
import training.feature_discovery.FeatureSetExpander;
import utils.ExponentialMovingAverage;
import utils.experiments.InterruptableExperiment;

/* loaded from: input_file:training/policy_gradients/Reinforce.class */
public class Reinforce {
    /**
     * Lower bound on the accumulated discount multiplier: while walking a trial backwards,
     * {@code addTrialData} stops collecting samples once the multiplier drops below this value.
     */
    private static final double EXPERIENCE_DISCOUNT_THRESHOLD = 0.001d;
    /**
     * Number of decisions per player per trial that {@code addTrialData} leaves unskipped when
     * {@code pgGamma == 1.0}; decisions beyond this count are randomly marked as skipped.
     */
    private static final int DATA_PER_TRIAL_THRESHOLD = 50;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:training/policy_gradients/Reinforce$PGExperience.class */
    /**
     * One decision point recorded during a self-play trial, together with the outcome of
     * that trial, in the form consumed by the policy-gradient update and by feature discovery.
     */
    public static class PGExperience extends ExperienceSample {
        /** Game state in which the decision was made. */
        protected final State state;
        /** "From" position of the move passed to the constructor, as reported by {@link FeatureUtils#fromPos}. */
        protected final int lastFromPos;
        /** "To" position of the move passed to the constructor, as reported by {@link FeatureUtils#toPos}. */
        protected final int lastToPos;
        /** Legal moves that were available at the decision point. */
        protected final FastArrayList<Move> legalMoves;
        /** One feature vector per legal move (same indexing as {@link #legalMoves}). */
        protected final FeatureVector[] featureVectors;
        /** Index (into the legal moves) of the move that was actually played. */
        protected final int movePlayedIdx;
        /** Return credited to this sample. */
        protected final float returns;
        /** Discount factor applied to {@link #returns} for this sample. */
        protected final double discountMultiplier;

        /**
         * @param state            state in which the decision was made
         * @param move             move from which the last from/to positions are extracted
         * @param fastArrayList    legal moves at the decision point
         * @param featureVectorArr feature vectors, one per legal move
         * @param i                index of the move that was played
         * @param f                return credited to this sample
         * @param d                discount multiplier for this sample
         */
        public PGExperience(State state, Move move, FastArrayList<Move> fastArrayList, FeatureVector[] featureVectorArr, int i, float f, double d) {
            this.state = state;
            lastFromPos = FeatureUtils.fromPos(move);
            lastToPos = FeatureUtils.toPos(move);
            legalMoves = fastArrayList;
            featureVectors = featureVectorArr;
            movePlayedIdx = i;
            returns = f;
            discountMultiplier = d;
        }

        /** @return the stored per-move feature vectors */
        public FeatureVector[] featureVectors() {
            return featureVectors;
        }

        /** @return index of the move that was played */
        public int movePlayedIdx() {
            return movePlayedIdx;
        }

        /** @return return credited to this sample */
        public float returns() {
            return returns;
        }

        /** @return discount multiplier for this sample */
        public double discountMultiplier() {
            return discountMultiplier;
        }

        /**
         * Returns the feature vectors captured when the sample was recorded; the given
         * feature set is not consulted.
         */
        @Override
        public FeatureVector[] generateFeatureVectors(BaseFeatureSet baseFeatureSet) {
            return featureVectors;
        }

        /**
         * Builds a distribution over the legal moves by writing the discounted return into the
         * played move's entry of a freshly allocated vector and applying a softmax to it.
         */
        @Override
        public FVector expertDistribution() {
            final float discountedReturn = (float) (returns * discountMultiplier);
            final FVector distribution = new FVector(featureVectors.length);
            distribution.set(movePlayedIdx, discountedReturn);
            distribution.softmax();
            return distribution;
        }

        @Override
        public State gameState() {
            return state;
        }

        @Override
        public int lastFromPos() {
            return lastFromPos;
        }

        @Override
        public int lastToPos() {
            return lastToPos;
        }

        @Override
        public FastArrayList<Move> moves() {
            return legalMoves;
        }

        /** @return a new, empty bit set (no winning-move information is recorded) */
        @Override
        public BitSet winningMoves() {
            return new BitSet();
        }

        /** @return a new, empty bit set (no losing-move information is recorded) */
        @Override
        public BitSet losingMoves() {
            return new BitSet();
        }

        /** @return a new, empty bit set (no anti-defeating-move information is recorded) */
        @Override
        public BitSet antiDefeatingMoves() {
            return new BitSet();
        }
    }

    /**
     * Runs epochs of self-play policy-gradient training.
     *
     * <p>Each epoch: (1) worker threads play trials by sampling moves from
     * {@code softmaxPolicyLinear2} and record one {@link PGExperience} per decision;
     * (2) for every player the per-sample gradients are averaged and handed to that player's
     * optimiser; (3) every fifth epoch, unless {@code featureDiscoveryParams.noGrowFeatureSet}
     * is set, each player's feature set is offered to the {@code featureSetExpander} and all
     * three policies are switched to the resulting feature sets.
     *
     * <p>The loop ends after {@code trainingParams.numPolicyGradientEpochs} epochs or as soon
     * as {@code interruptableExperiment.wantsInterrupt()} returns true at the start of an epoch.
     *
     * @param game2                   game to play trials of
     * @param softmaxPolicyLinear     policy that only has its feature sets updated here
     * @param softmaxPolicyLinear2    policy that is sampled from during self-play and whose
     *                                weights are optimised
     * @param softmaxPolicyLinear3    policy that only has its feature sets updated here
     * @param baseFeatureSetArr       initial feature sets, indexed by player; entries may be null
     * @param featureSetExpander      strategy used to grow feature sets
     * @param optimiserArr            optimisers, indexed by player (1..number of players)
     * @param objectiveParams         forwarded to the feature set expander
     * @param featureDiscoveryParams  feature discovery settings
     * @param trainingParams          training settings (thread counts, epochs, batch size, ...)
     * @param printWriter             log target passed on to the expander and the experiment
     * @param interruptableExperiment experiment used for logging and interruption checks
     * @return the feature sets in use when training stopped (the input array if they never grew)
     */
    @SuppressWarnings("unchecked") // arrays of generic per-player lists can only be created from raw List[]
    public static BaseFeatureSet[] runSelfPlayPG(Game game2, SoftmaxPolicyLinear softmaxPolicyLinear, SoftmaxPolicyLinear softmaxPolicyLinear2, SoftmaxPolicyLinear softmaxPolicyLinear3, BaseFeatureSet[] baseFeatureSetArr, FeatureSetExpander featureSetExpander, Optimiser[] optimiserArr, ObjectiveParams objectiveParams, FeatureDiscoveryParams featureDiscoveryParams, TrainingParams trainingParams, PrintWriter printWriter, InterruptableExperiment interruptableExperiment) {
        // Feature sets are grown once every this many epochs.
        final int epochsPerFeatureGrowth = 5;

        BaseFeatureSet[] featureSets = baseFeatureSetArr;
        final int numPlayers = game2.players().count();

        // Per-player running averages (index 0 unused): trial length and final utility.
        final ExponentialMovingAverage[] avgTrialLengths = new ExponentialMovingAverage[numPlayers + 1];
        final ExponentialMovingAverage[] avgOutcomes = new ExponentialMovingAverage[numPlayers + 1];
        for (int player = 1; player <= numPlayers; player++) {
            avgTrialLengths[player] = new ExponentialMovingAverage();
            avgOutcomes[player] = new ExponentialMovingAverage();
        }

        // Per feature set: how many observations each spatial feature has had, and the
        // fraction of those observations in which it was active.
        final TLongArrayList[] featureLifetimes = new TLongArrayList[featureSets.length];
        final TDoubleArrayList[] featureActiveRatios = new TDoubleArrayList[featureSets.length];
        for (int idx = 0; idx < featureSets.length; idx++) {
            if (featureSets[idx] != null) {
                final int numSpatialFeatures = featureSets[idx].getNumSpatialFeatures();
                final TLongArrayList lifetimes = new TLongArrayList();
                lifetimes.fill(0, numSpatialFeatures, 0L);
                featureLifetimes[idx] = lifetimes;
                final TDoubleArrayList ratios = new TDoubleArrayList();
                ratios.fill(0, numSpatialFeatures, 0.0d);
                featureActiveRatios[idx] = ratios;
            }
        }

        final ExecutorService trialPool = Executors.newFixedThreadPool(trainingParams.numPolicyGradientThreads, DaemonThreadFactory.INSTANCE);
        try {
            for (int epoch = 0; epoch < trainingParams.numPolicyGradientEpochs && !interruptableExperiment.wantsInterrupt(); epoch++) {
                final List<PGExperience>[] epochExperiences = new List[numPlayers + 1];
                for (int player = 1; player <= numPlayers; player++) {
                    epochExperiences[player] = new ArrayList<>();
                }

                softmaxPolicyLinear2.initAI(game2, -1);

                // One count per worker: the latch must only open once every worker has
                // stopped playing, otherwise the gradient step below would read the
                // experience lists while workers are still appending to them.
                final CountDownLatch workersDone = new CountDownLatch(trainingParams.numPolicyGradientThreads);
                final AtomicInteger trialsClaimed = new AtomicInteger(0);
                final BaseFeatureSet[] epochFeatureSets = featureSets;

                for (int worker = 0; worker < trainingParams.numPolicyGradientThreads; worker++) {
                    trialPool.submit(() -> {
                        try {
                            while (trialsClaimed.getAndIncrement() < trainingParams.numTrialsPerPolicyGradientEpoch) {
                                // Per-player records of this single trial (index 0 unused).
                                final List<State>[] states = new List[numPlayers + 1];
                                final List<Move>[] lastMoves = new List[numPlayers + 1];
                                final List<FastArrayList<Move>>[] legalMoveLists = new List[numPlayers + 1];
                                final List<FeatureVector[]>[] featureVectorLists = new List[numPlayers + 1];
                                final TIntArrayList[] playedMoveIndices = new TIntArrayList[numPlayers + 1];
                                for (int player = 1; player <= numPlayers; player++) {
                                    states[player] = new ArrayList<>();
                                    lastMoves[player] = new ArrayList<>();
                                    legalMoveLists[player] = new ArrayList<>();
                                    featureVectorLists[player] = new ArrayList<>();
                                    playedMoveIndices[player] = new TIntArrayList();
                                }

                                final Trial trial = new Trial(game2);
                                final Context context = new Context(game2, trial);
                                game2.start(context);

                                while (!trial.over()) {
                                    final int mover = context.state().mover();
                                    final FastArrayList<Move> moves = game2.moves(context).moves();
                                    final FeatureVector[] moveFeatures = epochFeatureSets[mover].computeFeatureVectors(context, moves, false);
                                    for (final FeatureVector featureVector : moveFeatures) {
                                        // These vectors are kept for the whole epoch; drop spare capacity.
                                        featureVector.activeSpatialFeatureIndices().trimToSize();
                                    }

                                    final int playedIdx = softmaxPolicyLinear2.selectActionFromDistribution(softmaxPolicyLinear2.computeDistribution(moveFeatures, mover));
                                    final Move playedMove = moves.get(playedIdx);

                                    // Record the decision before the move changes the context.
                                    states[mover].add(new Context(context).state());
                                    lastMoves[mover].add(context.trial().lastMove());
                                    legalMoveLists[mover].add(new FastArrayList<Move>(moves));
                                    featureVectorLists[mover].add(moveFeatures);
                                    playedMoveIndices[mover].add(playedIdx);

                                    game2.apply(context, playedMove);
                                    updateFeatureActivityData(moveFeatures, featureLifetimes, featureActiveRatios, mover);
                                }

                                addTrialData(epochExperiences, numPlayers, states, lastMoves, legalMoveLists, featureVectorLists, playedMoveIndices, RankUtils.agentUtilities(context), avgTrialLengths, avgOutcomes, trainingParams);
                            }
                        } catch (Exception e) {
                            // A failing trial ends this worker; the remaining workers keep going.
                            e.printStackTrace();
                        } finally {
                            workersDone.countDown();
                        }
                    });
                }

                try {
                    workersDone.await();
                } catch (InterruptedException e) {
                    // Interruption is only logged; training carries on.
                    e.printStackTrace();
                }
                softmaxPolicyLinear2.closeAI();

                // Gradient step: average the per-sample gradients of every player.
                for (int player = 1; player <= numPlayers; player++) {
                    final List<PGExperience> experiences = epochExperiences[player];
                    final int numSamples = experiences.size();
                    final FVector meanGradients = new FVector(softmaxPolicyLinear2.linearFunction(player).trainableParams().allWeights().dim());
                    final double baseline = avgOutcomes[player].movingAvg();

                    for (final PGExperience experience : experiences) {
                        final float playedMoveProb = softmaxPolicyLinear2.computeDistribution(experience.featureVectors(), player).get(experience.movePlayedIdx());
                        final FVector gradients = computePolicyGradients(experience, meanGradients.dim(), baseline, trainingParams.entropyRegWeight, playedMoveProb);
                        gradients.div(numSamples);
                        meanGradients.add(gradients);
                    }

                    optimiserArr[player].maximiseObjective(softmaxPolicyLinear2.linearFunction(player).trainableParams().allWeights(), meanGradients);
                }

                // Periodic feature discovery.
                if (!featureDiscoveryParams.noGrowFeatureSet && (epoch + 1) % epochsPerFeatureGrowth == 0) {
                    final BaseFeatureSet[] grownFeatureSets = new BaseFeatureSet[numPlayers + 1];
                    final ExecutorService discoveryPool = Executors.newFixedThreadPool(featureDiscoveryParams.numFeatureDiscoveryThreads);
                    final CountDownLatch discoveryDone = new CountDownLatch(numPlayers);

                    for (int player = 1; player <= numPlayers; player++) {
                        final int p = player;
                        final BaseFeatureSet previousFeatureSet = featureSets[p];
                        discoveryPool.submit(() -> {
                            try {
                                // Draw a random batch (without replacement) from this epoch's samples.
                                final int batchSize = trainingParams.batchSize;
                                final List<PGExperience> batch = new ArrayList<>(batchSize);
                                while (batch.size() < batchSize && !epochExperiences[p].isEmpty()) {
                                    final int pick = ThreadLocalRandom.current().nextInt(epochExperiences[p].size());
                                    batch.add(epochExperiences[p].get(pick));
                                    ListUtils.removeSwap(epochExperiences[p], pick);
                                }

                                if (batch.isEmpty()) {
                                    grownFeatureSets[p] = previousFeatureSet;
                                } else {
                                    final long startTime = System.currentTimeMillis();
                                    final BaseFeatureSet expanded = featureSetExpander.expandFeatureSet(batch, previousFeatureSet, softmaxPolicyLinear2, game2, featureDiscoveryParams.combiningFeatureInstanceThreshold, objectiveParams, featureDiscoveryParams, featureActiveRatios[p], printWriter, interruptableExperiment);
                                    if (expanded != null) {
                                        grownFeatureSets[p] = expanded;
                                        expanded.init(game2, new int[]{p}, null);
                                        // Start activity tracking for the newly added features.
                                        while (featureActiveRatios[p].size() < expanded.getNumSpatialFeatures()) {
                                            featureLifetimes[p].add(0L);
                                            featureActiveRatios[p].add(0.0d);
                                        }
                                    } else {
                                        grownFeatureSets[p] = previousFeatureSet;
                                    }
                                    JITSPatterNetFeatureSet.clearFeatureSetCache();
                                    interruptableExperiment.logLine(printWriter, "Expanded feature set in " + (System.currentTimeMillis() - startTime) + " ms for P" + p + ".");
                                }
                            } catch (Exception e) {
                                e.printStackTrace();
                                // Never hand a null feature set to the policies: if nothing was
                                // stored for this player before the failure, keep the old one.
                                if (grownFeatureSets[p] == null) {
                                    grownFeatureSets[p] = previousFeatureSet;
                                }
                            } finally {
                                discoveryDone.countDown();
                            }
                        });
                    }

                    try {
                        discoveryDone.await();
                    } catch (InterruptedException e) {
                        // Interruption is only logged; training carries on.
                        e.printStackTrace();
                    } finally {
                        discoveryPool.shutdown();
                    }

                    softmaxPolicyLinear.updateFeatureSets(grownFeatureSets);
                    softmaxPolicyLinear2.updateFeatureSets(grownFeatureSets);
                    softmaxPolicyLinear3.updateFeatureSets(grownFeatureSets);
                    featureSets = grownFeatureSets;
                }
            }
        } finally {
            // Also release the worker threads when an epoch fails unexpectedly.
            trialPool.shutdownNow();
        }
        return featureSets;
    }

    /**
     * Computes the gradient contribution of a single experience sample.
     *
     * <p>The result is {@code (phi(played move) - mean over legal moves of phi(move))}
     * multiplied by {@code discount * (returns - d) - d2 * ln(f)}. Note that the mean over
     * the legal moves is a plain (unweighted) average. In the weight layout assumed here the
     * aspatial feature values come first and the spatial features follow, with every active
     * spatial feature contributing a value of 1.
     *
     * @param pGExperience sample to compute the gradients for
     * @param i            dimensionality of the gradient vector
     * @param d            baseline subtracted from the sample's returns
     * @param d2           weight of the {@code ln(f)} term
     * @param f            probability that the current policy assigns to the played move
     * @return newly allocated gradient vector of dimension {@code i}
     */
    private static FVector computePolicyGradients(PGExperience pGExperience, int i, double d, double d2, float f) {
        final FeatureVector[] perMoveFeatures = pGExperience.featureVectors();
        final int playedIdx = pGExperience.movePlayedIdx();
        final FVector meanFeatures = new FVector(i);
        final FVector gradLogPi = new FVector(i);

        for (int moveIdx = 0; moveIdx < perMoveFeatures.length; moveIdx++) {
            final boolean isPlayedMove = (moveIdx == playedIdx);
            final FeatureVector moveFeatures = perMoveFeatures[moveIdx];

            // Aspatial features occupy the first entries of the weight vector.
            final FVector aspatialValues = moveFeatures.aspatialFeatureValues();
            final int spatialOffset = aspatialValues.dim();
            for (int aspatialIdx = 0; aspatialIdx < spatialOffset; aspatialIdx++) {
                meanFeatures.addToEntry(aspatialIdx, aspatialValues.get(aspatialIdx));
            }
            if (isPlayedMove) {
                for (int aspatialIdx = 0; aspatialIdx < spatialOffset; aspatialIdx++) {
                    gradLogPi.addToEntry(aspatialIdx, aspatialValues.get(aspatialIdx));
                }
            }

            // Spatial features are binary and stored sparsely, shifted past the aspatial block.
            final TIntArrayList activeSpatial = moveFeatures.activeSpatialFeatureIndices();
            for (int k = 0; k < activeSpatial.size(); k++) {
                final int weightIdx = activeSpatial.getQuick(k) + spatialOffset;
                meanFeatures.addToEntry(weightIdx, 1.0f);
                if (isPlayedMove) {
                    gradLogPi.addToEntry(weightIdx, 1.0f);
                }
            }
        }

        meanFeatures.div(perMoveFeatures.length);
        gradLogPi.subtract(meanFeatures);

        final double baselinedReturn = pGExperience.discountMultiplier() * (pGExperience.returns() - d);
        final double logProbTerm = d2 * Math.log(f);
        gradLogPi.mult((float) (baselinedReturn - logProbTerm));
        return gradLogPi;
    }

    /**
     * Folds the feature vectors of one decision into the running activity statistics of
     * feature set {@code i}.
     *
     * <p>For every feature vector with at least one active spatial feature, the observation
     * count of all features is increased by one and every feature's active ratio is moved
     * towards 1 (feature active) or 0 (feature inactive) with the incremental-mean update
     * {@code ratio += (x - ratio) / count}. Vectors without active spatial features are
     * ignored. The sparse index list of each processed vector is sorted in place.
     *
     * @param featureVectorArr    feature vectors of all legal moves of one decision
     * @param tLongArrayListArr   per feature set: observation counts per spatial feature
     * @param tDoubleArrayListArr per feature set: active ratios per spatial feature
     * @param i                   index of the feature set to update
     */
    private static synchronized void updateFeatureActivityData(FeatureVector[] featureVectorArr, TLongArrayList[] tLongArrayListArr, TDoubleArrayList[] tDoubleArrayListArr, int i) {
        for (final FeatureVector moveFeatures : featureVectorArr) {
            final TIntArrayList sparse = moveFeatures.activeSpatialFeatureIndices();
            if (sparse.isEmpty()) {
                continue;
            }

            // The merge below walks both lists in lockstep and needs ascending indices.
            sparse.sort();

            // Every feature has now been observed once more.
            tLongArrayListArr[i].transformValues(lifetime -> lifetime + 1);

            final TDoubleArrayList activeRatios = tDoubleArrayListArr[i];
            int vectorIdx = 0;
            for (int featureIdx = 0; featureIdx < activeRatios.size(); featureIdx++) {
                final double oldRatio = activeRatios.getQuick(featureIdx);
                final boolean isActive = vectorIdx < sparse.size() && sparse.getQuick(vectorIdx) == featureIdx;
                final double observation = isActive ? 1.0d : 0.0d;
                activeRatios.setQuick(featureIdx, oldRatio + ((observation - oldRatio) / tLongArrayListArr[i].getQuick(featureIdx)));
                if (isActive) {
                    vectorIdx++;
                }
            }

            // Every active index should have been matched exactly once.
            if (vectorIdx != sparse.size()) {
                System.err.println("ERROR: expected vectorIdx == sparse.size()!");
                System.err.println("vectorIdx = " + vectorIdx);
                System.err.println("sparse.size() = " + sparse.size());
                System.err.println("sparse = " + sparse);
            }
        }
    }

    /**
     * Converts the per-player records of one finished trial into {@link PGExperience}
     * samples and appends them to the epoch's experience lists.
     *
     * <p>For every player the two moving averages are updated with the number of decisions
     * the player made and with the player's final utility. Decisions are then visited from
     * last to first; the discount multiplier starts at 1 and is multiplied by
     * {@code trainingParams.pgGamma} after every visited decision, and collection stops once
     * it falls below {@link #EXPERIENCE_DISCOUNT_THRESHOLD}. Decisions with a single legal
     * move are never stored. When {@code pgGamma == 1.0} (no decay, so the threshold never
     * triggers), randomly chosen decisions are skipped until at most
     * {@link #DATA_PER_TRIAL_THRESHOLD} remain as candidates.
     *
     * @param listArr                    output: per-player experience lists of the epoch
     * @param i                          number of players
     * @param listArr2                   per player: states at the player's decisions
     * @param listArr3                   per player: last move before each decision
     * @param listArr4                   per player: legal moves at each decision
     * @param listArr5                   per player: feature vectors at each decision
     * @param tIntArrayListArr           per player: index of the move played at each decision
     * @param dArr                       final utilities, indexed by player
     * @param exponentialMovingAverageArr  per player: moving average fed with the decision count
     * @param exponentialMovingAverageArr2 per player: moving average fed with the utility
     * @param trainingParams             training settings ({@code pgGamma} is read)
     */
    private static synchronized void addTrialData(List<PGExperience>[] listArr, int i, List<State>[] listArr2, List<Move>[] listArr3, List<FastArrayList<Move>>[] listArr4, List<FeatureVector[]>[] listArr5, TIntArrayList[] tIntArrayListArr, double[] dArr, ExponentialMovingAverage[] exponentialMovingAverageArr, ExponentialMovingAverage[] exponentialMovingAverageArr2, TrainingParams trainingParams) {
        for (int player = 1; player <= i; player++) {
            final List<State> states = listArr2[player];
            final List<Move> lastMoves = listArr3[player];
            final List<FastArrayList<Move>> legalMoveLists = listArr4[player];
            final List<FeatureVector[]> featureVectorLists = listArr5[player];
            final TIntArrayList playedMoveIndices = tIntArrayListArr[player];

            exponentialMovingAverageArr[player].observe(states.size());
            exponentialMovingAverageArr2[player].observe(dArr[player]);

            final int numDecisions = featureVectorLists.size();
            final boolean[] skip = new boolean[numDecisions];
            if (trainingParams.pgGamma == 1.0d) {
                // Without decay every decision would be kept; thin long trials out at random.
                int numSkipped = 0;
                while (numDecisions - numSkipped > DATA_PER_TRIAL_THRESHOLD) {
                    final int candidate = ThreadLocalRandom.current().nextInt(numDecisions);
                    if (!skip[candidate]) {
                        skip[candidate] = true;
                        numSkipped++;
                    }
                }
            }

            // Walk backwards so that the latest decisions carry the largest multiplier.
            double discountMultiplier = 1.0d;
            for (int decision = numDecisions - 1; decision >= 0; decision--) {
                if (!skip[decision] && legalMoveLists.get(decision).size() > 1) {
                    listArr[player].add(new PGExperience(states.get(decision), lastMoves.get(decision), legalMoveLists.get(decision), featureVectorLists.get(decision), playedMoveIndices.getQuick(decision), (float) dArr[player], discountMultiplier));
                }
                discountMultiplier *= trainingParams.pgGamma;
                if (discountMultiplier < EXPERIENCE_DISCOUNT_THRESHOLD) {
                    break;
                }
            }
        }
    }
}
