package utils.data_structures.experience_buffers;

import game.Game;
import game.equipment.container.Container;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import other.state.container.ContainerState;
import training.expert_iteration.ExItExperience;

/* loaded from: input_file:utils/data_structures/experience_buffers/PrioritizedReplayBuffer.class */
/**
 * Fixed-capacity experience replay buffer with prioritized sampling.
 *
 * <p>Experiences are written into a circular array: the slot for the next insertion is
 * {@code addCount % replayCapacity}, so once the buffer is full the oldest entry is overwritten.
 * Each slot has a priority stored in a {@link SumTree}; explicitly supplied priorities are raised
 * to the power {@code alpha} before being stored. Prioritized sampling draws slots via the sum
 * tree's stratified sampler and attaches importance-sampling weights computed with exponent
 * {@code beta}.
 *
 * <p>This class performs no synchronization.
 */
public class PrioritizedReplayBuffer implements Serializable, ExperienceBuffer {
    private static final long serialVersionUID = 1L;

    /** Maximum number of experiences held; once reached, the oldest slot is overwritten. */
    protected final int replayCapacity;

    /** Per-slot priorities used for proportional sampling. */
    protected final SumTree sumTree;

    /** Backing storage for experiences; slots that have never been written are {@code null}. */
    protected final ExItExperience[] buffer;

    /** Total number of insertions performed so far (not capped at the capacity). */
    protected long addCount;

    /** Exponent applied to raw priorities before they are stored in the sum tree. */
    protected final double alpha;

    /** Exponent used when computing importance-sampling weights. */
    protected final double beta;

    /**
     * Creates a buffer with {@code alpha = beta = 0.5}.
     *
     * @param replayCapacity maximum number of experiences to retain
     */
    public PrioritizedReplayBuffer(final int replayCapacity) {
        this(replayCapacity, 0.5d, 0.5d);
    }

    /**
     * Creates a buffer.
     *
     * @param replayCapacity maximum number of experiences to retain
     * @param alpha exponent applied to priorities before storing them
     * @param beta exponent for importance-sampling weights
     */
    public PrioritizedReplayBuffer(final int replayCapacity, final double alpha, final double beta) {
        this.replayCapacity = replayCapacity;
        this.sumTree = new SumTree(replayCapacity);
        this.buffer = new ExItExperience[replayCapacity];
        this.addCount = 0L;
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Inserts an experience with the largest priority recorded so far by the sum tree, overwriting
     * the oldest entry if the buffer is full.
     *
     * @param experience experience to store
     */
    @Override // utils.data_structures.experience_buffers.ExperienceBuffer
    public void add(final ExItExperience experience) {
        final int slot = cursor();
        this.sumTree.set(slot, this.sumTree.maxRecordedPriority());
        this.buffer[slot] = experience;
        this.addCount++;
    }

    /**
     * Inserts an experience with an explicit priority, overwriting the oldest entry if the buffer
     * is full.
     *
     * @param experience experience to store
     * @param priority raw priority; {@code priority^alpha} is what gets stored
     */
    public void add(final ExItExperience experience, final float priority) {
        final int slot = cursor();
        this.sumTree.set(slot, (float) Math.pow(priority, this.alpha));
        this.buffer[slot] = experience;
        this.addCount++;
    }

    /**
     * Looks up the stored (already alpha-exponentiated) priorities of the given slots.
     *
     * @param indices buffer slot indices
     * @return a new array with one priority per index, in the same order
     */
    public float[] getPriorities(final int[] indices) {
        final float[] priorities = new float[indices.length];
        for (int k = 0; k < indices.length; k++) {
            priorities[k] = this.sumTree.get(indices[k]);
        }
        return priorities;
    }

    /** @return {@code true} if nothing has ever been added */
    public boolean isEmpty() {
        return this.addCount == 0;
    }

    /** @return {@code true} once at least {@code replayCapacity} experiences have been added */
    public boolean isFull() {
        return this.addCount >= (long) this.replayCapacity;
    }

    /** @return number of occupied slots, i.e. {@code min(addCount, replayCapacity)} */
    public int size() {
        return isFull() ? this.replayCapacity : (int) this.addCount;
    }

    /**
     * Draws slot indices from the sum tree's stratified sampler.
     *
     * @param batchSize number of indices to draw
     * @return the sampled slot indices
     */
    public int[] sampleIndexBatch(final int batchSize) {
        return this.sumTree.stratifiedSample(batchSize);
    }

    /**
     * Samples up to {@code batchSize} experiences proportionally to their stored priorities.
     *
     * <p>Each returned experience is annotated with its importance-sampling weight
     * {@code ((1/N) * (1/P(i)))^beta}, divided by the largest weight in the batch, where
     * {@code N = size()} and {@code P(i) = priority(i) / totalPriority}. It is also annotated with
     * the buffer slot it came from. The same experience may appear more than once; its annotations
     * are then those of its last occurrence.
     *
     * @param batchSize requested number of samples
     * @return a new list of at most {@code min(batchSize, addCount)} experiences; empty if
     *     {@code batchSize <= 0} or the buffer is empty
     */
    @Override // utils.data_structures.experience_buffers.ExperienceBuffer
    public List<ExItExperience> sampleExperienceBatch(final int batchSize) {
        final int numSamples = (int) Math.min(batchSize, this.addCount);
        if (numSamples <= 0) {
            return new ArrayList<>();
        }

        final int[] indices = sampleIndexBatch(numSamples);

        // The sampler can return slots past the filled region, so clamp to the last filled slot.
        // size() is used rather than casting addCount to int, which would overflow (and yield
        // negative indices) once addCount exceeds Integer.MAX_VALUE.
        final int lastFilledSlot = size() - 1;
        for (int k = 0; k < numSamples; k++) {
            if (indices[k] > lastFilledSlot) {
                indices[k] = lastFilledSlot;
            }
        }

        final float[] priorities = getPriorities(indices);
        final double inverseSize = 1.0d / size();
        final double[] weights = new double[numSamples];
        double maxWeight = Double.NEGATIVE_INFINITY;
        final List<ExItExperience> batch = new ArrayList<>(numSamples);

        for (int k = 0; k < numSamples; k++) {
            batch.add(this.buffer[indices[k]]);
            final double probability = priorities[k] / this.sumTree.totalPriority();
            weights[k] = Math.pow(inverseSize * (1.0d / probability), this.beta);
            maxWeight = Math.max(maxWeight, weights[k]);
        }

        // Normalise so that the largest weight in the batch becomes 1.
        for (int k = 0; k < numSamples; k++) {
            final ExItExperience experience = batch.get(k);
            experience.setWeightPER((float) (weights[k] / maxWeight));
            experience.setBufferIdx(indices[k]);
        }
        return batch;
    }

    /**
     * Samples up to {@code batchSize} experiences uniformly at random (with replacement) from the
     * occupied slots. Priorities and per-experience annotations are not touched.
     *
     * @param batchSize requested number of samples
     * @return a new list of at most {@code min(batchSize, addCount)} experiences; empty if
     *     {@code batchSize <= 0} or the buffer is empty
     */
    @Override // utils.data_structures.experience_buffers.ExperienceBuffer
    public List<ExItExperience> sampleExperienceBatchUniformly(final int batchSize) {
        final int numSamples = (int) Math.min(batchSize, this.addCount);
        if (numSamples <= 0) {
            return new ArrayList<>();
        }
        final List<ExItExperience> batch = new ArrayList<>(numSamples);
        final int numStored = size();
        final ThreadLocalRandom rng = ThreadLocalRandom.current();
        for (int k = 0; k < numSamples; k++) {
            batch.add(this.buffer[rng.nextInt(numStored)]);
        }
        return batch;
    }

    /**
     * Returns the live backing array (not a copy). Slots that have not yet been written are
     * {@code null}.
     *
     * @return the internal experience array
     */
    @Override // utils.data_structures.experience_buffers.ExperienceBuffer
    public ExItExperience[] allExperience() {
        return this.buffer;
    }

    /**
     * Updates the priorities of the given slots; each stored value is {@code priority^alpha}.
     * Negative indices are skipped.
     *
     * @param indices buffer slot indices
     * @param priorities raw priorities, parallel to {@code indices}
     */
    public void setPriorities(final int[] indices, final float[] priorities) {
        assert indices.length == priorities.length
                : "indices.length (" + indices.length + ") != priorities.length ("
                        + priorities.length + ")";
        for (int k = 0; k < indices.length; k++) {
            // NOTE(review): negative indices presumably mark samples with no buffer slot — confirm.
            if (indices[k] >= 0) {
                this.sumTree.set(indices[k], (float) Math.pow(priorities[k], this.alpha));
            }
        }
    }

    /** @return the sum tree holding per-slot priorities */
    public SumTree sumTree() {
        return this.sumTree;
    }

    /** @return the priority exponent */
    public double alpha() {
        return this.alpha;
    }

    /** @return the importance-sampling exponent */
    public double beta() {
        return this.beta;
    }

    /** @return total number of insertions performed so far */
    public long addCount() {
        return this.addCount;
    }

    /** @return the slot that the next insertion will write to */
    public int cursor() {
        return (int) (this.addCount % this.replayCapacity);
    }

    /**
     * Loads a buffer previously written by {@link #writeToFile(String)}. Afterwards, every
     * container state is re-attached to the container of {@code game} whose name matches
     * {@code nameFromFile()}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files from trusted sources.
     *
     * @param game game whose containers should be linked into the loaded states
     * @param filepath path of the serialized buffer
     * @return the loaded buffer, or {@code null} if reading failed (the stack trace is printed)
     */
    public static PrioritizedReplayBuffer fromFile(final Game game, final String filepath) {
        try (ObjectInputStream in =
                new ObjectInputStream(new BufferedInputStream(new FileInputStream(filepath)))) {
            final PrioritizedReplayBuffer replayBuffer = (PrioritizedReplayBuffer) in.readObject();
            for (final ExItExperience experience : replayBuffer.buffer) {
                if (experience != null) {
                    relinkContainers(experience, game);
                }
            }
            return replayBuffer;
        } catch (final IOException | ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Points each non-null container state of {@code experience} at the same-named container. */
    private static void relinkContainers(final ExItExperience experience, final Game game) {
        for (final ContainerState containerState : experience.state().state().containerStates()) {
            if (containerState == null) {
                continue;
            }
            final String name = containerState.nameFromFile();
            for (final Container container : game.equipment().containers()) {
                if (container != null && container.name().equals(name)) {
                    containerState.setContainer(container);
                    break;
                }
            }
        }
    }

    /**
     * Serializes this buffer to {@code filepath}. Errors are reported by printing the stack trace.
     *
     * @param filepath destination file
     */
    @Override // utils.data_structures.experience_buffers.ExperienceBuffer
    public void writeToFile(final String filepath) {
        try (ObjectOutputStream out =
                new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filepath)))) {
            out.writeObject(this);
            out.flush();
        } catch (final IOException e) {
            e.printStackTrace();
        }
    }
}
