package optimisers;

import com.itextpdf.text.pdf.ColumnText;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Locale;
import main.collections.FVector;
import org.apache.batik.svggen.SVGSyntax;

/* loaded from: input_file:optimisers/AMSGrad.class */
/**
 * Gradient-based optimiser that keeps three running statistics between calls:
 * a moving average of the gradients, a moving average of the element-wise
 * squared gradients, and the element-wise running maximum of the latter.
 * Each call to {@link #maximiseObjective(FVector, FVector)} adds a step to the
 * parameter vector that is the scaled gradient average divided by the square
 * root of the scaled running maximum.
 */
public class AMSGrad extends Optimiser {
    private static final long serialVersionUID = 1;

    /** Base step size used by {@link #fromLines(String[])} when no line overrides it. */
    private static final float DEFAULT_BASE_STEP_SIZE = 3.0E-4f;
    /** Default decay factor of the gradient moving average. */
    private static final float DEFAULT_BETA1 = 0.9f;
    /** Default decay factor of the squared-gradient moving average. */
    private static final float DEFAULT_BETA2 = 0.999f;
    /** Default constant added to the denominator of the update. */
    private static final float DEFAULT_EPSILON = 1.0E-8f;

    /** Lower-case {@code key=} prefixes recognised by {@link #fromLines(String[])}. */
    private static final String BASE_STEP_SIZE_KEY = "basestepsize=";
    private static final String BETA1_KEY = "beta1=";
    private static final String BETA2_KEY = "beta2=";
    private static final String EPSILON_KEY = "epsilon=";

    /** Decay factor of the gradient moving average. */
    protected final float beta1;
    /** Decay factor of the squared-gradient moving average. */
    protected final float beta2;
    /** Constant added to the denominator of the update. */
    protected final float epsilon;

    /** Moving average of gradients; {@code null} until the first update. */
    private FVector movingAvgGradients;
    /** Moving average of element-wise squared gradients; {@code null} until the first update. */
    private FVector movingAvgSquaredGradients;
    /** Element-wise maximum over time of {@link #movingAvgSquaredGradients}; {@code null} until the first update. */
    private FVector maxMovingAvgSquaredGradients;

    /**
     * Creates an optimiser with beta1 = 0.9, beta2 = 0.999 and epsilon = 1e-8.
     *
     * @param f base step size, passed to the {@code Optimiser} constructor
     */
    public AMSGrad(float f) {
        this(f, DEFAULT_BETA1, DEFAULT_BETA2, DEFAULT_EPSILON);
    }

    /**
     * Creates an optimiser with explicit hyper-parameters.
     *
     * @param f  base step size, passed to the {@code Optimiser} constructor
     * @param f2 beta1, decay factor of the gradient moving average
     * @param f3 beta2, decay factor of the squared-gradient moving average
     * @param f4 epsilon, constant added to the denominator of the update
     */
    public AMSGrad(float f, float f2, float f3, float f4) {
        super(f);
        this.beta1 = f2;
        this.beta2 = f3;
        this.epsilon = f4;
    }

    /**
     * Updates the running statistics with a new gradient vector and adds the
     * resulting step to the parameter vector.
     *
     * <p>NOTE(review): the {@code FVector} mutators called here ({@code mult},
     * {@code addScaled}, {@code hadamardProduct}, {@code div}, {@code sqrt},
     * {@code add}, {@code elementwiseDivision}) are used as in-place operations,
     * i.e. their results are not reassigned - confirm against {@code FVector}.
     *
     * @param fVector  parameter vector; the computed step is added to it
     * @param fVector2 gradient vector for the current step
     */
    @Override
    public void maximiseObjective(FVector fVector, FVector fVector2) {
        if (this.movingAvgGradients == null) {
            // First call: allocate all statistics at the gradient's dimensionality.
            this.movingAvgGradients = new FVector(fVector2.dim());
            this.movingAvgSquaredGradients = new FVector(fVector2.dim());
            this.maxMovingAvgSquaredGradients = new FVector(fVector2.dim());
        } else {
            // The gradient vector may have grown since the last call; pad every
            // statistic with zero entries until the dimensions match again.
            while (this.movingAvgGradients.dim() < fVector2.dim()) {
                this.movingAvgGradients = this.movingAvgGradients.append(0.0f);
                this.movingAvgSquaredGradients = this.movingAvgSquaredGradients.append(0.0f);
                this.maxMovingAvgSquaredGradients = this.maxMovingAvgSquaredGradients.append(0.0f);
            }
        }

        // movingAvgGradients = beta1 * movingAvgGradients + (1 - beta1) * gradients
        this.movingAvgGradients.mult(this.beta1);
        this.movingAvgGradients.addScaled(fVector2, 1.0f - this.beta1);

        // movingAvgSquaredGradients = beta2 * movingAvgSquaredGradients + (1 - beta2) * gradients^2
        FVector squaredGradients = fVector2.copy();
        squaredGradients.hadamardProduct(squaredGradients);
        this.movingAvgSquaredGradients.mult(this.beta2);
        this.movingAvgSquaredGradients.addScaled(squaredGradients, 1.0f - this.beta2);

        // Keep the largest squared-gradient average seen so far, per element.
        this.maxMovingAvgSquaredGradients =
                FVector.elementwiseMax(this.maxMovingAvgSquaredGradients, this.movingAvgSquaredGradients);

        // Numerator: gradient average scaled by baseStepSize / (1 - beta1).
        FVector step = this.movingAvgGradients.copy();
        step.mult(this.baseStepSize / (1.0f - this.beta1));

        // Denominator: sqrt(max / (1 - beta2)) + epsilon.
        FVector denominator = this.maxMovingAvgSquaredGradients.copy();
        denominator.div(1.0f - this.beta2);
        denominator.sqrt();
        denominator.add(this.epsilon);

        step.elementwiseDivision(denominator);
        fVector.add(step);
    }

    /**
     * Builds an optimiser from textual {@code key=value} lines.
     *
     * <p>Only the text before the first comma of each line is inspected. The
     * keys {@code basestepsize=}, {@code beta1=}, {@code beta2=} and
     * {@code epsilon=} are matched case-insensitively; any other line is
     * ignored. Keys that never appear keep their defaults (3e-4, 0.9, 0.999
     * and 1e-8 respectively). When a key appears more than once, the last
     * occurrence wins.
     *
     * @param strArr lines to parse
     * @return a new optimiser configured from the lines
     * @throws NumberFormatException if the text after a recognised key is not a valid float
     */
    public static AMSGrad fromLines(String[] strArr) {
        float stepSize = DEFAULT_BASE_STEP_SIZE;
        float b1 = DEFAULT_BETA1;
        float b2 = DEFAULT_BETA2;
        float eps = DEFAULT_EPSILON;

        for (String str : strArr) {
            String[] split = str.split(",");
            if (split.length == 0) {
                // A line made only of commas yields no tokens (split drops
                // trailing empty strings), so there is nothing to inspect.
                continue;
            }
            String token = split[0];
            // Locale.ROOT keeps key matching independent of the default locale.
            String lowered = token.toLowerCase(Locale.ROOT);
            if (lowered.startsWith(BASE_STEP_SIZE_KEY)) {
                stepSize = Float.parseFloat(token.substring(BASE_STEP_SIZE_KEY.length()));
            } else if (lowered.startsWith(BETA1_KEY)) {
                b1 = Float.parseFloat(token.substring(BETA1_KEY.length()));
            } else if (lowered.startsWith(BETA2_KEY)) {
                b2 = Float.parseFloat(token.substring(BETA2_KEY.length()));
            } else if (lowered.startsWith(EPSILON_KEY)) {
                eps = Float.parseFloat(token.substring(EPSILON_KEY.length()));
            }
        }
        return new AMSGrad(stepSize, b1, b2, eps);
    }

    /**
     * Writes this optimiser to a file with {@link ObjectOutputStream#writeObject(Object)}.
     *
     * <p>An {@link IOException} is not propagated; its stack trace is printed
     * instead. The stream is closed whether or not writing succeeds.
     *
     * @param str path of the file to write
     */
    @Override
    public void writeToFile(String str) {
        // try-with-resources closes (and thereby flushes) the stream on every path.
        try (ObjectOutputStream out =
                new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(str)))) {
            out.writeObject(this);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
