package decision_trees.classifiers;

import com.itextpdf.text.pdf.ColumnText;
import features.FeatureVector;
import features.WeightVector;
import features.aspatial.AspatialFeature;
import features.feature_sets.BaseFeatureSet;
import function_approx.LinearFunction;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Comparator;
import java.util.List;
import main.collections.ArrayUtils;
import main.math.MathRoutines;
import training.expert_iteration.ExItExperience;
import utils.data_structures.experience_buffers.ExperienceBuffer;

/**
 * Builds decision trees that classify each move's feature vector into one of
 * three classes: bottom 25%, interquartile range, or top 25%. The class is
 * determined by how the move's score ranks among all moves of the same state.
 * <p>
 * A move's score is the dot product of its feature vector with the effective
 * parameters of a {@link LinearFunction}. Moves recorded as winning (losing) in
 * the experience are forced to the state's maximum (minimum) score before
 * ranking. Splits are chosen greedily by information gain over the three classes.
 */
public class ExperienceIQRTreeLearner {

    /** Index of the Bottom25 count in per-class count arrays. */
    private static final int BOTTOM = 0;
    /** Index of the IQR count in per-class count arrays. */
    private static final int MIDDLE = 1;
    /** Index of the Top25 count in per-class count arrays. */
    private static final int TOP = 2;

    /**
     * Class label of a single move's feature vector.
     * {@code UNDEFINED} is never assigned by the code in this class. Labels
     * other than the three real classes are reported on stderr and excluded
     * from split statistics.
     */
    public enum IQRClass {
        UNDEFINED,
        Bottom25,
        IQR,
        Top25
    }

    /**
     * Builds a decision tree from all experience in the given buffer.
     *
     * @param featureSet        feature set used to generate feature vectors and to look up split features
     * @param linFunc           linear function whose effective parameters score the moves
     * @param experienceBuffer  source of training states
     * @param maxDepth          remaining depth allowed; a node reached with depth 0 becomes a leaf
     * @param minSamplesPerLeaf minimum number of samples required on each side of a split
     * @return root of the learned tree
     * @throws IllegalArgumentException if {@code minSamplesPerLeaf <= 0}
     */
    public static DecisionTreeNode buildTree(
            final BaseFeatureSet featureSet,
            final LinearFunction linFunc,
            final ExperienceBuffer experienceBuffer,
            final int maxDepth,
            final int minSamplesPerLeaf) {
        final WeightVector oracleWeights = linFunc.effectiveParams();
        final List<FeatureVector> allFeatureVectors = new ArrayList<FeatureVector>();
        final List<IQRClass> allLabels = new ArrayList<IQRClass>();

        for (final ExItExperience sample : experienceBuffer.allExperience()) {
            // States with at most one legal move carry no ranking information
            if (sample == null || sample.moves().size() <= 1) {
                continue;
            }

            final FeatureVector[] featureVectors = sample.generateFeatureVectors(featureSet);
            final float[] logits = new float[featureVectors.length];
            for (int i = 0; i < featureVectors.length; ++i) {
                logits[i] = oracleWeights.dot(featureVectors[i]);
            }

            final float maxLogit = ArrayUtils.max(logits);
            final float minLogit = ArrayUtils.min(logits);
            if (maxLogit == minLogit) {
                // Every move scored identically: nothing to rank, skip the state
                continue;
            }

            // Moves known to win (lose) are ranked at the top (bottom) regardless of their score
            for (int i = sample.winningMoves().nextSetBit(0); i >= 0; i = sample.winningMoves().nextSetBit(i + 1)) {
                logits[i] = maxLogit;
            }
            for (int i = sample.losingMoves().nextSetBit(0); i >= 0; i = sample.losingMoves().nextSetBit(i + 1)) {
                logits[i] = minLogit;
            }

            final IQRClass[] labels = classifyByQuartile(logits);
            for (int i = 0; i < featureVectors.length; ++i) {
                allFeatureVectors.add(featureVectors[i]);
                allLabels.add(labels[i]);
            }
        }

        return buildNode(
                featureSet, allFeatureVectors, allLabels, new BitSet(), new BitSet(),
                featureSet.getNumAspatialFeatures(), featureSet.getNumSpatialFeatures(),
                maxDepth, minSamplesPerLeaf);
    }

    /**
     * Labels every entry of {@code logits} as Bottom25, IQR or Top25 according to its rank.
     * <p>
     * Each tail receives {@code max(1, round(n / 4))} entries, capped at {@code n / 2}
     * so the tails never overlap. Entries in the middle that tie with a tail's
     * boundary value join that tail. If both boundaries coincide, every entry with
     * that value is labelled IQR instead.
     *
     * @param logits one score per move (not modified)
     * @return one label per entry of {@code logits}, indexed like {@code logits}
     */
    private static IQRClass[] classifyByQuartile(final float[] logits) {
        final int n = logits.length;
        final List<Integer> sortedIndices = ArrayUtils.sortedIndices(n, new Comparator<Integer>() {
            @Override
            public int compare(final Integer a, final Integer b) {
                final float diff = logits[a.intValue()] - logits[b.intValue()];
                if (diff < 0.f) {
                    return -1;
                }
                return diff > 0.f ? 1 : 0;
            }
        });

        // At least one entry per tail (Math.max, not Math.min), but never overlapping tails
        final int numPerTail = (int) Math.min(n / 2, Math.max(1L, Math.round(0.25 * n)));
        final int topStart = n - numPerTail;
        final IQRClass[] labels = new IQRClass[n];

        float bottomBoundary = Float.NEGATIVE_INFINITY;   // largest score in the bottom tail
        for (int rank = 0; rank < numPerTail; ++rank) {
            final int idx = sortedIndices.get(rank).intValue();
            labels[idx] = IQRClass.Bottom25;
            bottomBoundary = Math.max(bottomBoundary, logits[idx]);
        }

        float topBoundary = Float.POSITIVE_INFINITY;      // smallest score in the top tail
        for (int rank = n - 1; rank >= topStart; --rank) {
            final int idx = sortedIndices.get(rank).intValue();
            labels[idx] = IQRClass.Top25;
            topBoundary = Math.min(topBoundary, logits[idx]);
        }

        for (int rank = numPerTail; rank < topStart; ++rank) {
            final int idx = sortedIndices.get(rank).intValue();
            final float logit = logits[idx];
            if (logit == topBoundary) {
                labels[idx] = IQRClass.Top25;
            } else if (logit == bottomBoundary) {
                labels[idx] = IQRClass.Bottom25;
            } else {
                labels[idx] = IQRClass.IQR;
            }
        }

        // Tails meeting at a single value cannot be told apart: treat that value as IQR
        if (topBoundary == bottomBoundary) {
            for (int idx = 0; idx < n; ++idx) {
                if (logits[idx] == topBoundary) {
                    labels[idx] = IQRClass.IQR;
                }
            }
        }

        return labels;
    }

    /**
     * Recursively builds a (sub)tree for the given samples.
     *
     * @param featureSet               feature set providing the split features
     * @param remainingFeatureVectors  samples reaching this node
     * @param remainingLabels          labels of those samples (same order and size)
     * @param alreadyPickedAspatials   aspatial features already split on along this path
     * @param alreadyPickedSpatials    spatial features already split on along this path
     * @param numAspatialFeatures      number of aspatial features
     * @param numSpatialFeatures       number of spatial features
     * @param allowedDepth             remaining depth; 0 forces a leaf
     * @param minSamplesPerLeaf        minimum samples on each side of a split
     * @return the subtree root
     * @throws IllegalArgumentException if {@code minSamplesPerLeaf <= 0}
     */
    private static DecisionTreeNode buildNode(
            final BaseFeatureSet featureSet,
            final List<FeatureVector> remainingFeatureVectors,
            final List<IQRClass> remainingLabels,
            final BitSet alreadyPickedAspatials,
            final BitSet alreadyPickedSpatials,
            final int numAspatialFeatures,
            final int numSpatialFeatures,
            final int allowedDepth,
            final int minSamplesPerLeaf) {
        if (minSamplesPerLeaf <= 0) {
            throw new IllegalArgumentException("minSamplesPerLeaf must be greater than 0");
        }

        if (remainingFeatureVectors.isEmpty()) {
            System.err.println("Empty list of remaining feature vectors!");
            return new DecisionLeafNode(0.33333334f, 0.33333334f, 0.33333334f);
        }

        int numBottom25 = 0;
        int numTop25 = 0;
        for (final IQRClass label : remainingLabels) {
            if (label == IQRClass.Bottom25) {
                ++numBottom25;
            } else if (label == IQRClass.Top25) {
                ++numTop25;
            }
        }

        // Floating-point division: integer division here would truncate the probabilities to 0
        final int numLabels = remainingLabels.size();
        final float probBottom25 = ((float) numBottom25) / numLabels;
        final float probTop25 = ((float) numTop25) / numLabels;
        final float probIQR = 1.f - probBottom25 - probTop25;

        if (allowedDepth == 0) {
            return new DecisionLeafNode(probBottom25, probIQR, probTop25);
        }

        final double entropyBeforeSplit = entropy(numBottom25, numLabels - numBottom25 - numTop25, numTop25);

        double bestGain = Double.NEGATIVE_INFINITY;
        double worstGain = Double.POSITIVE_INFINITY;
        int bestIdx = -1;
        boolean bestIsAspatial = true;

        // Aspatial features are scanned first; with strict '>' they win ties against spatial ones
        for (int pass = 0; pass < 2; ++pass) {
            final boolean aspatial = (pass == 0);
            final int numFeatures = aspatial ? numAspatialFeatures : numSpatialFeatures;
            final BitSet alreadyPicked = aspatial ? alreadyPickedAspatials : alreadyPickedSpatials;

            for (int featureIdx = 0; featureIdx < numFeatures; ++featureIdx) {
                if (alreadyPicked.get(featureIdx)) {
                    continue;
                }

                final int[] countsIfFalse = new int[3];
                final int[] countsIfTrue = new int[3];
                for (int i = 0; i < remainingFeatureVectors.size(); ++i) {
                    final boolean active = featureActive(remainingFeatureVectors.get(i), aspatial, featureIdx);
                    tally(active ? countsIfTrue : countsIfFalse, remainingLabels.get(i));
                }

                final int totalIfFalse = sum(countsIfFalse);
                final int totalIfTrue = sum(countsIfTrue);
                if (totalIfFalse < minSamplesPerLeaf || totalIfTrue < minSamplesPerLeaf) {
                    continue;
                }

                final double propFalse = ((double) totalIfFalse) / (totalIfFalse + totalIfTrue);
                final double gain = entropyBeforeSplit
                        - propFalse * entropy(countsIfFalse)
                        - (1.0 - propFalse) * entropy(countsIfTrue);

                if (gain > bestGain) {
                    bestGain = gain;
                    bestIdx = featureIdx;
                    bestIsAspatial = aspatial;
                }
                if (gain < worstGain) {
                    worstGain = gain;
                }
            }
        }

        // Stop when no split is admissible, no split gains information (<= 0 also covers
        // rounding-negative gains), or every admissible split is equally good
        if (bestIdx == -1 || bestGain <= 0.0 || worstGain == bestGain) {
            return new DecisionLeafNode(probBottom25, probIQR, probTop25);
        }

        final BitSet newPickedAspatials;
        final BitSet newPickedSpatials;
        if (bestIsAspatial) {
            newPickedAspatials = (BitSet) alreadyPickedAspatials.clone();
            newPickedAspatials.set(bestIdx);
            newPickedSpatials = alreadyPickedSpatials;
        } else {
            newPickedSpatials = (BitSet) alreadyPickedSpatials.clone();
            newPickedSpatials.set(bestIdx);
            newPickedAspatials = alreadyPickedAspatials;
        }

        final List<FeatureVector> fvsIfTrue = new ArrayList<FeatureVector>();
        final List<IQRClass> labelsIfTrue = new ArrayList<IQRClass>();
        final List<FeatureVector> fvsIfFalse = new ArrayList<FeatureVector>();
        final List<IQRClass> labelsIfFalse = new ArrayList<IQRClass>();
        for (int i = 0; i < remainingFeatureVectors.size(); ++i) {
            final FeatureVector featureVector = remainingFeatureVectors.get(i);
            if (featureActive(featureVector, bestIsAspatial, bestIdx)) {
                fvsIfTrue.add(featureVector);
                labelsIfTrue.add(remainingLabels.get(i));
            } else {
                fvsIfFalse.add(featureVector);
                labelsIfFalse.add(remainingLabels.get(i));
            }
        }

        final DecisionTreeNode trueBranch = buildNode(
                featureSet, fvsIfTrue, labelsIfTrue, newPickedAspatials, newPickedSpatials,
                numAspatialFeatures, numSpatialFeatures, allowedDepth - 1, minSamplesPerLeaf);
        final DecisionTreeNode falseBranch = buildNode(
                featureSet, fvsIfFalse, labelsIfFalse, newPickedAspatials, newPickedSpatials,
                numAspatialFeatures, numSpatialFeatures, allowedDepth - 1, minSamplesPerLeaf);

        // Construct per branch: aspatial and spatial feature arrays have different element types,
        // so a single AspatialFeature-typed local cannot hold both
        if (bestIsAspatial) {
            return new DecisionConditionNode(featureSet.aspatialFeatures()[bestIdx], trueBranch, falseBranch);
        }
        return new DecisionConditionNode(featureSet.spatialFeatures()[bestIdx], trueBranch, falseBranch);
    }

    /** Returns whether the given aspatial (non-zero value) or spatial (active index) feature is active. */
    private static boolean featureActive(final FeatureVector featureVector, final boolean aspatial, final int featureIdx) {
        if (aspatial) {
            return featureVector.aspatialFeatureValues().get(featureIdx) != 0.f;
        }
        return featureVector.activeSpatialFeatureIndices().contains(featureIdx);
    }

    /** Increments the count for {@code label} in {@code counts}; unknown labels are reported and not counted. */
    private static void tally(final int[] counts, final IQRClass label) {
        switch (label) {
            case Bottom25:
                ++counts[BOTTOM];
                break;
            case IQR:
                ++counts[MIDDLE];
                break;
            case Top25:
                ++counts[TOP];
                break;
            default:
                System.err.println("Unrecognised IQR class!");
                break;
        }
    }

    /** Returns the sum of all entries of {@code counts}. */
    private static int sum(final int[] counts) {
        int total = 0;
        for (final int count : counts) {
            total += count;
        }
        return total;
    }

    /** Returns the base-2 Shannon entropy of the empirical distribution given by {@code counts} (0 if empty). */
    private static double entropy(final int... counts) {
        final int total = sum(counts);
        if (total == 0) {
            return 0.0;
        }
        double result = 0.0;
        for (final int count : counts) {
            if (count > 0) {
                final double p = ((double) count) / total;
                result -= p * MathRoutines.log2(p);
            }
        }
        return result;
    }
}
