package com.oplus.deepthinker.internal.api.algorithm.randomforest;

import android.util.SparseArray;
import android.util.SparseIntArray;
import com.oplus.deepthinker.internal.api.utils.OplusLog;
import com.oplus.deepthinker.sdk.app.aidl.eventfountain.EventType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/* loaded from: classes2.dex */
/**
 * A bagged ensemble of {@link DecisionTree}s over integer feature vectors with integer labels.
 *
 * <p>Each tree is trained on its own bootstrap sample of the training data. Rows are drawn with
 * replacement, and the sample size is capped at {@code MAX_TRAIN_DATA_NUM}. Predictions are made
 * by counting the votes of all trees.
 *
 * <p>NOTE(review): this class is {@link Serializable} but declares no {@code serialVersionUID},
 * so the JVM derives one from its non-private members. If instances are persisted, confirm
 * compatibility with stored data before changing non-private members or adding an explicit UID.
 */
public class RandomForest implements Serializable {
    private static final int DEFAULT_TREE_NUM = 50;
    private static final int LABEL_NUMBER = 1;
    private static final int MAX_ADJUST_TREE_NUM = 200;
    private static final int MAX_TRAIN_DATA_NUM = 20000;
    private static final int MIN_ADJUST_TREE_NUM = 100;
    private static final String TAG = "RandomForest";
    /** Minimum number of votes a label needs to be reported by {@link #predict(int[])}. */
    private static final int TREE_THRESHOLD = 3;

    /** Trained trees; {@code null} until {@link #train(int, int[][], int[])} succeeds. */
    private DecisionTree[] mForest;
    /** Number of trees built by the next call to {@link #train(int, int[][], int[])}. */
    private int mTreeNum = DEFAULT_TREE_NUM;

    /**
     * One bootstrap sample drawn from the training data.
     *
     * <p>Field names come from the decompiler: {@code f4634a} was {@code a} and {@code f4635b}
     * was {@code b}.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public class BaggingInstances {

        /** Sampled feature rows; each is a copy of the corresponding source row. */
        int[][] f4634a;

        /** Index in the original training arrays of each sampled row. */
        int[] f4635b;

        /** Label of each sampled row. */
        int[] c;

        BaggingInstances() {
        }
    }

    /** Orders vote-count entries by descending count; null entries compare as equal. */
    private static class ValueComparator implements Comparator<Map.Entry<Integer, Integer>> {
        private ValueComparator() {
        }

        @Override // java.util.Comparator
        public int compare(Map.Entry<Integer, Integer> entry, Map.Entry<Integer, Integer> entry2) {
            if (entry == null || entry2 == null) {
                return 0;
            }
            // Integer.compare instead of subtraction: subtraction can overflow.
            return Integer.compare(entry2.getValue(), entry.getValue());
        }
    }

    public RandomForest() {
        OplusLog.v(TAG, "RandomForest created");
    }

    /**
     * Draws a bootstrap sample (with replacement) of at most {@code MAX_TRAIN_DATA_NUM} rows.
     *
     * @param iArr feature rows; must have at least {@code iArr2.length} entries
     * @param iArr2 labels, one per row; must be non-empty
     * @param random source of randomness shared across the trees of one training run
     * @return the sampled rows (cloned), their labels and their original indices
     */
    private BaggingInstances baggingInstances(int[][] iArr, int[] iArr2, Random random) {
        BaggingInstances sample = new BaggingInstances();
        int length = iArr2.length;
        int min = Math.min(length, MAX_TRAIN_DATA_NUM);
        OplusLog.d(TAG, "data size:" + length + ", fetch num:" + min);
        sample.f4634a = new int[min][];
        sample.c = new int[min];
        sample.f4635b = new int[min];
        for (int i = 0; i < min; i++) {
            int index = random.nextInt(length);
            sample.f4634a[i] = iArr[index].clone();
            sample.c[i] = iArr2[index];
            sample.f4635b[i] = index;
        }
        return sample;
    }

    /**
     * Sets the number of trees used by the next training run.
     *
     * @param i requested tree count, clamped to
     *     [{@code MIN_ADJUST_TREE_NUM}, {@code MAX_ADJUST_TREE_NUM}]
     */
    public void adjustTreeNum(int i) {
        if (i < MIN_ADJUST_TREE_NUM) {
            this.mTreeNum = MIN_ADJUST_TREE_NUM;
        } else if (i > MAX_ADJUST_TREE_NUM) {
            this.mTreeNum = MAX_ADJUST_TREE_NUM;
        } else {
            this.mTreeNum = i;
        }
        OplusLog.v(TAG, "Forest mTreeNum = " + this.mTreeNum);
    }

    /**
     * Returns the labels predicted by at least {@code TREE_THRESHOLD} trees, ordered by
     * descending vote count.
     *
     * <p>Trees that return -1 are skipped; -1 appears to mean "no prediction".
     *
     * @param iArr feature vector to classify
     * @return the qualifying labels (possibly empty), or {@code null} when {@code iArr} is null,
     *     the forest has not been trained, or no tree produced a prediction
     */
    public List<Integer> predict(int[] iArr) {
        DecisionTree[] forest = this.mForest;
        if (iArr == null || forest == null) {
            return null;
        }
        Map<Integer, Integer> votes = new HashMap<>();
        for (DecisionTree tree : forest) {
            int label = tree.predict(iArr);
            if (label != -1) {
                Integer count = votes.get(label);
                votes.put(label, count == null ? 1 : count + 1);
            }
        }
        if (votes.isEmpty()) {
            return null;
        }
        List<Map.Entry<Integer, Integer>> ranked = new ArrayList<>(votes.entrySet());
        Collections.sort(ranked, new ValueComparator());
        int size = ranked.size();
        OplusLog.v(TAG, "predictList size is:" + size);
        List<Integer> result = new ArrayList<>();
        for (Map.Entry<Integer, Integer> entry : ranked) {
            if (entry.getValue() >= TREE_THRESHOLD) {
                result.add(entry.getKey());
            }
        }
        return result;
    }

    /**
     * Returns, for each label predicted by at least one tree, the fraction of trees in the forest
     * that predicted it.
     *
     * <p>Trees that return -1 are skipped but still count toward the denominator.
     *
     * @param iArr feature vector to classify
     * @return label-to-fraction map, or {@code null} when {@code iArr} is null, the forest has not
     *     been trained, or no tree produced a prediction
     */
    public SparseArray<Float> predicting(int[] iArr) {
        DecisionTree[] forest = this.mForest;
        if (iArr == null || forest == null) {
            return null;
        }
        SparseIntArray votes = new SparseIntArray();
        for (DecisionTree tree : forest) {
            int label = tree.predict(iArr);
            if (label != -1) {
                // SparseIntArray.get returns 0 for absent keys.
                votes.put(label, votes.get(label) + 1);
            }
        }
        if (votes.size() == 0) {
            return null;
        }
        // Divide by the real forest size rather than a fixed 50: adjustTreeNum can set 100-200
        // trees, and a fixed divisor would then yield values above 1.
        float treeCount = forest.length;
        SparseArray<Float> probabilities = new SparseArray<>();
        for (int i = 0; i < votes.size(); i++) {
            int label = votes.keyAt(i);
            float probability = votes.valueAt(i) / treeCount;
            probabilities.put(label, probability);
            OplusLog.d(TAG, "label is: " + label + " probability is : " + probability);
        }
        return probabilities;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        DecisionTree[] forest = this.mForest;
        if (forest != null) {
            for (int i = 0; i < forest.length; i++) {
                DecisionTree tree = forest[i];
                if (tree == null) {
                    continue;
                }
                sb.append("mForest_")
                        .append(i)
                        .append(EventType.EventAssociationExtra.JOINT)
                        .append(tree.toString())
                        .append("\t");
            }
        }
        return sb.toString();
    }

    /**
     * Builds a new forest of {@code mTreeNum} trees, each trained on its own bootstrap sample.
     *
     * <p>The new forest replaces the current one only after every tree has trained. If training
     * throws, the previously trained forest (if any) stays in use.
     *
     * @param i passed through to each {@link DecisionTree} constructor; must be positive
     *     (NOTE(review): presumably the feature count - confirm against DecisionTree)
     * @param iArr feature rows; must contain at least as many rows as {@code iArr2} has labels
     * @param iArr2 labels, one per training row; must be non-empty
     */
    public void train(int i, int[][] iArr, int[] iArr2) {
        if (iArr == null || iArr2 == null || i <= 0) {
            OplusLog.w(TAG, "Input parameter is illegal, please check!");
            return;
        }
        OplusLog.d(TAG, "feature.length = " + iArr.length + " mLabels.length = " + iArr2.length);
        if (iArr2.length == 0 || iArr.length < iArr2.length) {
            // An empty label array would make Random.nextInt(0) throw. Fewer rows than labels
            // would index past the end of iArr.
            OplusLog.w(TAG, "Training data is empty or features/labels size mismatch!");
            return;
        }
        int treeNum = this.mTreeNum;
        Random random = new Random();
        DecisionTree[] forest = new DecisionTree[treeNum];
        for (int t = 0; t < treeNum; t++) {
            BaggingInstances sample = baggingInstances(iArr, iArr2, random);
            DecisionTree tree = new DecisionTree(i, sample.f4634a, sample.c);
            tree.train();
            forest[t] = tree;
        }
        this.mForest = forest;
    }
}
