package com.oplus.deepthinker.internal.api.proton.learn.algorithm;

import android.content.Context;
import android.os.PersistableBundle;
import com.oplus.deepthinker.internal.api.proton.learn.data.DataSet;
import com.oplus.deepthinker.internal.api.utils.OplusLog;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

/* loaded from: classes2.dex */
/**
 * Runs every registered {@link AbstractAlgorithm} against a data set and keeps the
 * resulting {@link AbstractModel} with the lowest validation cost.
 *
 * <p>When an {@link Executor} has been supplied through {@link #setExecutor(Executor)}
 * each algorithm is trained as a separate task on that executor; otherwise the
 * algorithms are trained one after another on the calling thread.
 *
 * <p>NOTE(review): the selected model is stored in a field and is never cleared, so a
 * model chosen by an earlier {@link #train(Context)} call can still be returned by a
 * later call. The sequential path replaces it with the best model of the current run as
 * soon as any algorithm produces one, whereas the executor path replaces it only when a
 * new model is strictly cheaper than the stored one. This behavior is preserved as-is;
 * confirm whether it is intended.
 *
 * @param <T> element type of the data sets and models handled by this trainer
 */
public abstract class AbstractTrain<T> {

    private static final String TAG = "AbstractTrain";

    /** Maximum time, in minutes, to wait for the executor tasks to finish. */
    private static final long TRAIN_TIMEOUT_MINUTES = 5L;

    /* renamed from: a, reason: collision with root package name */
    /** Best model found so far; written under {@link #d} by the executor tasks. */
    private AbstractModel<T> f4814a;

    /* renamed from: b, reason: collision with root package name */
    /** Optional executor; when non-null, algorithms are trained as tasks on it. */
    protected Executor f4815b;
    /** Algorithms to evaluate; entries may be {@code null} and are then skipped. */
    protected List<AbstractAlgorithm<T>> c = new ArrayList<>();
    /** Lock guarding the best-model field while executor tasks are running. */
    protected final Object d = new Object();
    /** Opaque arguments exposed through {@link #getArgs()} / {@link #setArgs(PersistableBundle)}. */
    private PersistableBundle e;

    /* loaded from: classes2.dex */
    /**
     * Pair of data sets: one handed to {@code AbstractAlgorithm.train} and one handed to
     * {@code AbstractModel.computeCost}.
     */
    public static class ClassifiedDataSet<T> {
        public DataSet<T> mTrainDataSet;
        public DataSet<T> mValidationDataSet;
    }

    /**
     * Returns a snapshot of the registered algorithms, or an empty list when the
     * algorithm list is {@code null}. Iterating a snapshot keeps the latch count in the
     * executor path equal to the number of iterated entries.
     */
    private List<AbstractAlgorithm<T>> snapshotAlgorithms() {
        List<AbstractAlgorithm<T>> registered = this.c;
        if (registered == null) {
            return new ArrayList<>();
        }
        return new ArrayList<>(registered);
    }

    /**
     * Picks the data set for one algorithm: its private data set when it has one,
     * otherwise the shared data set (which may itself be {@code null}).
     */
    private ClassifiedDataSet<T> resolveDataSet(AbstractAlgorithm<T> algorithm, ClassifiedDataSet<T> shared) {
        ClassifiedDataSet<T> own = algorithm.getPrivateDataSet();
        return own != null ? own : shared;
    }

    /**
     * Sequential training: trains each algorithm on the calling thread and stores the
     * model with the lowest cost of this run.
     *
     * @param context           context forwarded to each algorithm
     * @param classifiedDataSet shared data set used by algorithms without a private one; may be {@code null}
     * @return the stored best model, or {@code null} if none has ever been stored
     */
    private AbstractModel<T> a(Context context, ClassifiedDataSet<T> classifiedDataSet) {
        double bestCost = Double.MAX_VALUE;
        for (AbstractAlgorithm<T> algorithm : snapshotAlgorithms()) {
            if (algorithm == null) {
                continue;
            }
            ClassifiedDataSet<T> dataSet = resolveDataSet(algorithm, classifiedDataSet);
            if (dataSet == null) {
                OplusLog.w(TAG, algorithm + " does not have valid dataSet.");
                continue;
            }
            AbstractModel<T> model = algorithm.train(context, dataSet.mTrainDataSet);
            if (model == null) {
                continue;
            }
            double cost = model.computeCost(dataSet.mValidationDataSet);
            model.setCost(cost);
            if (cost < bestCost) {
                this.f4814a = model;
                bestCost = cost;
            }
        }
        return this.f4814a;
    }

    /**
     * Executor-based training: submits one task per usable algorithm and waits up to
     * {@link #TRAIN_TIMEOUT_MINUTES} minutes for all of them to finish.
     *
     * <p>If the wait times out or is interrupted, tasks may still be running; the best
     * model stored at that moment is returned and the condition is logged.
     *
     * @param context           context forwarded to each algorithm
     * @param classifiedDataSet shared data set used by algorithms without a private one; may be {@code null}
     * @return the stored best model, or {@code null} if none has ever been stored
     */
    private AbstractModel<T> b(final Context context, ClassifiedDataSet<T> classifiedDataSet) {
        List<AbstractAlgorithm<T>> algorithms = snapshotAlgorithms();
        final CountDownLatch countDownLatch = new CountDownLatch(algorithms.size());
        for (final AbstractAlgorithm<T> abstractAlgorithm : algorithms) {
            if (abstractAlgorithm == null) {
                countDownLatch.countDown();
                continue;
            }
            final ClassifiedDataSet<T> privateDataSet = resolveDataSet(abstractAlgorithm, classifiedDataSet);
            if (privateDataSet == null) {
                // Same diagnostic as the sequential path, so skipped algorithms are visible in both modes.
                OplusLog.w(TAG, abstractAlgorithm + " does not have valid dataSet.");
                countDownLatch.countDown();
                continue;
            }
            this.f4815b.execute(new Runnable() { // from class: com.oplus.deepthinker.internal.api.proton.learn.algorithm.AbstractTrain.1
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        AbstractModel<T> train = abstractAlgorithm.train(context, privateDataSet.mTrainDataSet);
                        if (train != null) {
                            double computeCost = train.computeCost(privateDataSet.mValidationDataSet);
                            train.setCost(computeCost);
                            synchronized (AbstractTrain.this.d) {
                                AbstractModel<T> current = AbstractTrain.this.f4814a;
                                double currentCost = current != null ? current.getCost() : Double.MAX_VALUE;
                                if (computeCost < currentCost) {
                                    AbstractTrain.this.f4814a = train;
                                }
                            }
                        }
                    } finally {
                        // Always release the latch, even if training or cost computation throws.
                        countDownLatch.countDown();
                    }
                }
            });
        }
        try {
            if (!countDownLatch.await(TRAIN_TIMEOUT_MINUTES, TimeUnit.MINUTES)) {
                OplusLog.w(TAG, "train latch timed out; " + countDownLatch.getCount() + " task(s) still pending.");
            }
        } catch (InterruptedException e) {
            OplusLog.e(TAG, "train latch interrupted!", e);
            // Restore the interrupt status so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
        }
        // After a timeout or interrupt, tasks may still write the field; read it under the same lock.
        synchronized (this.d) {
            return this.f4814a;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Supplies the shared data set used by algorithms that have no private data set.
     *
     * @param context context passed to {@link #train(Context)}
     * @return the shared data set; a {@code null} result causes algorithms without a private data set to be skipped
     */
    public abstract ClassifiedDataSet<T> a(Context context);

    /**
     * Registers an algorithm to be evaluated by {@link #train(Context)}.
     *
     * @param abstractAlgorithm algorithm to add; a {@code null} entry is stored but skipped during training
     */
    // The parameter stays raw to keep the existing signature; the list is only read as AbstractAlgorithm<T>.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void addAlgorithm(AbstractAlgorithm abstractAlgorithm) {
        if (this.c == null) {
            this.c = new ArrayList<>();
        }
        this.c.add(abstractAlgorithm);
    }

    /** Returns the bundle previously set with {@link #setArgs(PersistableBundle)}, or {@code null}. */
    public PersistableBundle getArgs() {
        return this.e;
    }

    /** Stores an argument bundle retrievable through {@link #getArgs()}. */
    public void setArgs(PersistableBundle persistableBundle) {
        this.e = persistableBundle;
    }

    /**
     * Sets the executor used to train algorithms as separate tasks; {@code null} selects
     * sequential training on the calling thread.
     */
    public void setExecutor(Executor executor) {
        this.f4815b = executor;
    }

    /**
     * Trains all registered algorithms and returns the selected model.
     *
     * @param context context forwarded to {@link #a(Context)} and to each algorithm
     * @return a single-element list holding the selected model, or {@code null} when no model is available
     */
    public List<AbstractModel<T>> train(Context context) {
        ClassifiedDataSet<T> dataSet = a(context);
        AbstractModel<T> best = this.f4815b != null ? b(context, dataSet) : a(context, dataSet);
        OplusLog.i(TAG, "train over: result " + best);
        if (best == null) {
            // Existing contract: callers receive null (not an empty list) when nothing was trained.
            return null;
        }
        List<AbstractModel<T>> result = new ArrayList<>(1);
        result.add(best);
        return result;
    }
}
