package policies;

import features.FeatureSet;
import function_approx.BoostedLinearFunction;
import function_approx.LinearFunction;
import game.Game;
import game.rules.play.moves.Moves;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ThreadLocalRandom;
import main.collections.FVector;
import main.collections.FastArrayList;
import org.apache.batik.util.SVGConstants;
import util.Context;
import util.Move;
import util.Trial;
import utils.ExperimentFileUtils;

/**
 * A greedy policy: each move is scored by a linear function over its sparse feature vector, and
 * all probability mass goes to the highest-scoring move(s), split evenly between exact ties.
 *
 * <p>{@link #linearFunctions} and {@link #featureSets} are either length-1 arrays shared by every
 * player, or arrays indexed by player number. In the per-player layout, {@link #initAI} skips
 * index 0 and {@link #runPlayout} tolerates {@code null} linear functions.
 */
public class GreedyPolicy extends Policy {
    /** Default value of {@link #playoutTurnLimit}. */
    private static final int DEFAULT_PLAYOUT_TURN_LIMIT = 200;

    /** Linear functions used to score sparse feature vectors (one shared, or one per player index). */
    protected LinearFunction[] linearFunctions;

    /** Feature sets used to compute sparse feature vectors; same layout as {@link #linearFunctions}. */
    protected FeatureSet[] featureSets;

    /** Passed to {@code Game.playout} by {@link #runPlayout}; settable via {@code playoutturnlimit=}. */
    protected int playoutTurnLimit;

    /** Creates an unconfigured policy; {@link #customise(String[])} is expected to fill it in. */
    public GreedyPolicy() {
        this(null, null);
    }

    /**
     * Creates a policy from already-built linear functions and feature sets.
     *
     * @param linearFunctions linear functions, either a single shared one or one per player index
     * @param featureSets     feature sets laid out the same way as {@code linearFunctions}
     */
    public GreedyPolicy(final LinearFunction[] linearFunctions, final FeatureSet[] featureSets) {
        this.playoutTurnLimit = DEFAULT_PLAYOUT_TURN_LIMIT;
        this.linearFunctions = linearFunctions;
        this.featureSets = featureSets;
    }

    /**
     * Computes the greedy distribution over {@code actions} for the player to move.
     *
     * @param context     current context; its mover selects the feature set and linear function
     * @param actions     candidate moves
     * @param thresholded forwarded unchanged to {@code FeatureSet.computeSparseFeatureVectors}
     *                    (NOTE(review): semantics are defined by FeatureSet — confirm)
     * @return one entry per action, as produced by {@link #computeDistribution(List, int)}
     */
    @Override // policies.Policy
    public FVector computeDistribution(
            final Context context, final FastArrayList<Move> actions, final boolean thresholded) {
        final int mover = context.state().mover();
        return computeDistribution(
                featureSetFor(mover).computeSparseFeatureVectors(context, actions, thresholded), mover);
    }

    /**
     * Scores a single move for the player to move in {@code context}.
     *
     * @param context current context; its mover selects the feature set and linear function
     * @param move    the move to score
     * @return the linear function's prediction for the move's sparse feature vector
     */
    @Override // policies.Policy
    public float computeLogit(final Context context, final Move move) {
        final int mover = context.state().mover();
        final FeatureSet featureSet = featureSetFor(mover);
        final LinearFunction linearFunction = linearFunctionFor(mover);

        final FastArrayList<Move> singleMove = new FastArrayList<>(1);
        singleMove.add(move);
        return linearFunction.predict(
                featureSet.computeSparseFeatureVectors(context, singleMove, true).get(0));
    }

    /**
     * Scores each sparse feature vector with the linear function for {@code player}.
     *
     * @param sparseFeatureVectors one sparse feature vector per move
     * @param player               player index (ignored when only one linear function exists)
     * @return an array of predictions, parallel to {@code sparseFeatureVectors}
     */
    public float[] computeLogits(final List<TIntArrayList> sparseFeatureVectors, final int player) {
        final LinearFunction linearFunction = linearFunctionFor(player);
        final float[] logits = new float[sparseFeatureVectors.size()];
        for (int i = 0; i < logits.length; ++i) {
            logits[i] = linearFunction.predict(sparseFeatureVectors.get(i));
        }
        return logits;
    }

    /**
     * Builds the greedy distribution: every move whose logit equals the maximum gets an equal share
     * {@code 1 / numTies}; all other entries are left unset in a fresh {@code FVector} of the same
     * length (presumably 0 — confirm FVector's initial values).
     *
     * <p>Ties use exact float equality. NaN logits never win, so if every logit is NaN no entry is
     * set.
     *
     * @param sparseFeatureVectors one sparse feature vector per move
     * @param player               player index (ignored when only one linear function exists)
     * @return a vector parallel to {@code sparseFeatureVectors}
     */
    public FVector computeDistribution(final List<TIntArrayList> sparseFeatureVectors, final int player) {
        final float[] logits = computeLogits(sparseFeatureVectors, player);

        // Single pass: keep the running maximum and every index that attains it.
        float maxLogit = Float.NEGATIVE_INFINITY;
        final TIntArrayList maxIndices = new TIntArrayList();
        for (int i = 0; i < logits.length; ++i) {
            final float logit = logits[i];
            if (logit > maxLogit) {
                maxLogit = logit;
                maxIndices.reset();
                maxIndices.add(i);
            } else if (logit == maxLogit) {
                maxIndices.add(i);
            }
        }

        final FVector distribution = new FVector(logits.length);
        final float share = 1.0f / maxIndices.size();
        for (int i = 0; i < maxIndices.size(); ++i) {
            distribution.set(maxIndices.getQuick(i), share);
        }
        return distribution;
    }

    /**
     * Runs a playout via {@code Game.playout}, passing this policy's feature sets and the
     * effective parameters of each linear function ({@code null} where the function is null).
     *
     * @param context context to play out from
     * @return the resulting trial
     */
    @Override // search.mcts.playout.PlayoutStrategy
    public Trial runPlayout(final Context context) {
        final FVector[] params = new FVector[linearFunctions.length];
        for (int i = 0; i < params.length; ++i) {
            params[i] = (linearFunctions[i] == null) ? null : linearFunctions[i].effectiveParams();
        }
        // NOTE(review): the literal arguments below are interpreted by Game.playout; kept verbatim.
        return context.game().playout(
                context, null, 1.0d, featureSets, params, -1, playoutTurnLimit, -1.0f,
                ThreadLocalRandom.current());
    }

    /** Delegates to {@code supportsGame}. */
    @Override // search.mcts.playout.PlayoutStrategy
    public boolean playoutSupportsGame(final Game game) {
        return supportsGame(game);
    }

    /**
     * Configures the policy from {@code key=value} strings. Keys are matched case-insensitively.
     * Recognised keys: {@code policyweights=}, {@code playoutturnlimit=}, {@code friendly_name=},
     * and {@code boosted=...true}. Element 0 of {@code inputs} is skipped (presumably the policy's
     * own name — confirm with callers).
     *
     * <p>If the weights file does not exist, the last {@code txt} file under
     * {@code <parent>/PolicyWeights} is used instead. Feature set files are resolved relative to
     * the parent directory of the given weights path. Configuration errors are reported on
     * {@code System.err} and leave the policy unconfigured.
     *
     * @param inputs configuration strings
     * @throws NumberFormatException if {@code playoutturnlimit=} is not an integer
     */
    @Override // search.mcts.playout.PlayoutStrategy
    public void customise(final String[] inputs) {
        String weightsFilepath = null;
        boolean boosted = false;

        for (int i = 1; i < inputs.length; ++i) {
            final String input = inputs[i];
            // Locale.ROOT: with e.g. a Turkish default locale, 'I' would lowercase to a dotless
            // 'ı' and "PolicyWeights=" / "PlayoutTurnLimit=" would fail to match.
            final String lower = input.toLowerCase(Locale.ROOT);
            if (lower.startsWith("policyweights=")) {
                weightsFilepath = input.substring("policyweights=".length());
            } else if (lower.startsWith("playoutturnlimit=")) {
                playoutTurnLimit = Integer.parseInt(input.substring("playoutturnlimit=".length()));
            } else if (lower.startsWith("friendly_name=")) {
                friendlyName = input.substring("friendly_name=".length());
            } else if (lower.startsWith("boosted=") && lower.endsWith("true")) {
                boosted = true;
            }
        }

        if (weightsFilepath == null) {
            System.err.println("Cannot construct Greedy Policy from: " + Arrays.toString(inputs));
            return;
        }

        // Go through the absolute path so a bare file name yields the working directory rather
        // than a null parent (which would otherwise produce paths beginning with "null").
        final String parentDir = new File(weightsFilepath).getAbsoluteFile().getParent();
        if (!new File(weightsFilepath).exists()) {
            weightsFilepath = ExperimentFileUtils.getLastFilepath(parentDir + "/PolicyWeights", "txt");
            if (weightsFilepath == null) {
                System.err.println("Cannot construct Greedy Policy from: " + Arrays.toString(inputs));
                return;
            }
        }

        final LinearFunction linearFunction;
        if (boosted) {
            linearFunction = BoostedLinearFunction.boostedFromFile(weightsFilepath, null);
        } else {
            linearFunction = LinearFunction.fromFile(weightsFilepath);
        }
        linearFunctions = new LinearFunction[] {linearFunction};

        featureSets = new FeatureSet[linearFunctions.length];
        for (int i = 0; i < linearFunctions.length; ++i) {
            if (linearFunctions[i] != null) {
                featureSets[i] =
                        new FeatureSet(parentDir + File.separator + linearFunctions[i].featureSetFile());
            }
        }
    }

    /**
     * Picks the legal move with the highest logit, breaking ties with {@code FVector.argMaxRand}.
     *
     * @param game          the game
     * @param context       current context; its mover selects the feature set and linear function
     * @param maxSeconds    search-budget argument from the AI interface; ignored by this policy
     * @param maxIterations search-budget argument from the AI interface; ignored by this policy
     * @param maxDepth      search-budget argument from the AI interface; ignored by this policy
     * @return the selected legal move
     */
    @Override // util.AI
    public Move selectAction(
            final Game game, final Context context, final double maxSeconds,
            final int maxIterations, final int maxDepth) {
        final Moves legalMoves = game.moves(context);
        final int mover = context.state().mover();
        final List<TIntArrayList> sparseFeatureVectors =
                featureSetFor(mover).computeSparseFeatureVectors(context, legalMoves.moves(), true);
        final float[] logits = computeLogits(sparseFeatureVectors, mover);
        return legalMoves.moves().get(FVector.wrap(logits).argMaxRand());
    }

    /**
     * Creates a policy configured by {@link #customise(String[])}.
     *
     * @param lines configuration strings
     * @return the new policy
     */
    public static GreedyPolicy fromLines(final String[] lines) {
        final GreedyPolicy policy = new GreedyPolicy();
        policy.customise(lines);
        return policy;
    }

    /**
     * Instantiates feature sets for {@code game}.
     *
     * <p>With one shared feature set, it is always instantiated for players
     * {@code 1..players().count()}. With per-player feature sets, entry {@code p} (for
     * {@code p >= 1}) is instantiated for player {@code p} only, unless it already reports being
     * instantiated.
     *
     * @param game     the game
     * @param playerID the player this AI controls; not used here
     */
    @Override // util.AI
    public void initAI(final Game game, final int playerID) {
        if (featureSets.length == 1) {
            final int[] supportedPlayers = new int[game.players().count()];
            for (int p = 0; p < supportedPlayers.length; ++p) {
                supportedPlayers[p] = p + 1;
            }
            featureSets[0].instantiateFeatures(game, supportedPlayers, linearFunctions[0].effectiveParams());
            return;
        }

        for (int p = 1; p < featureSets.length; ++p) {
            if (!featureSets[p].hasInstantiatedFeatures(game, linearFunctions[p].effectiveParams())) {
                featureSets[p].instantiateFeatures(game, new int[] {p}, linearFunctions[p].effectiveParams());
            }
        }
    }

    /** Returns the shared feature set if there is only one, otherwise the one at {@code player}. */
    private FeatureSet featureSetFor(final int player) {
        return featureSets.length == 1 ? featureSets[0] : featureSets[player];
    }

    /** Returns the shared linear function if there is only one, otherwise the one at {@code player}. */
    private LinearFunction linearFunctionFor(final int player) {
        return linearFunctions.length == 1 ? linearFunctions[0] : linearFunctions[player];
    }
}
