package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.PairedReturn;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * A single-layer perceptron for binary classification of purely numeric data.
 * <p>
 * A data point is assigned class {@code 1} when {@code weights·x + bias >= 0} and class
 * {@code 0} otherwise. Two training modes are available:
 * <ul>
 * <li>online (per-example updates): {@link #train(ClassificationDataSet)} /
 * {@link #trainCOnline(ClassificationDataSet)}</li>
 * <li>batch (one summed update per pass, with the pass split across executor tasks):
 * {@link #train(ClassificationDataSet, boolean)}</li>
 * </ul>
 * In both modes, training ends with the weights and bias recorded on the pass that had the
 * lowest weighted training error, not necessarily the last pass.
 */
public class Perceptron implements BinaryScoreClassifier, SingleWeightVectorModel {
    private static final long serialVersionUID = -3605237847981632021L;

    // NOTE: field names (including the misspelled iteratinLimit) are kept as-is so that the
    // serialized form stays compatible with previously saved models.

    /** Step size applied to every weight/bias update; validated to lie in (0, 1]. */
    private double learningRate;
    /** Intercept added to the dot product in {@link #getScore(DataPoint)}. */
    private double bias;
    /** Learned weight vector; {@code null} until the model has been trained. */
    private Vec weights;
    /** Maximum number of passes over the training data. */
    private int iteratinLimit;

    /**
     * Computes the summed batch update for one contiguous slice of the training data.
     * It uses the enclosing model's weights and bias as they stand when the task runs.
     * The enclosing model is only read here, never written, so several units can be run
     * within the same pass.
     */
    private class BatchTrainingUnit implements Callable<PairedReturn<Vec, Double[]>> {
        private final Vec tmpSummedErrors;
        private final List<DataPointPair<Integer>> dataPoints;
        private double globalError = 0.0d;
        private double biasChange = 0.0d;

        BatchTrainingUnit(List<DataPointPair<Integer>> list) {
            this.tmpSummedErrors = new DenseVector(Perceptron.this.weights.length());
            this.dataPoints = list;
        }

        /**
         * Accumulates the perceptron update over this unit's data points.
         *
         * @return the summed weight update as the first item, and as the second item the
         *         two-element array {@code {biasChange, weightedError}}
         */
        @Override
        public PairedReturn<Vec, Double[]> call() {
            for (DataPointPair<Integer> dataPointPair : this.dataPoints) {
                double localError = dataPointPair.getPair().intValue()
                        - Perceptron.this.output(dataPointPair.getDataPoint());
                if (localError != 0.0d) {
                    double weight = dataPointPair.getDataPoint().getWeight();
                    double step = Perceptron.this.learningRate * localError * weight;
                    this.tmpSummedErrors.mutableAdd(step, dataPointPair.getVector());
                    this.biasChange += step;
                    this.globalError += Math.abs(localError) * weight;
                }
            }
            return new PairedReturn<>(this.tmpSummedErrors,
                    new Double[]{Double.valueOf(this.biasChange), Double.valueOf(this.globalError)});
        }
    }

    /** Creates a perceptron with learning rate 0.1 and at most 400 training passes. */
    public Perceptron() {
        this(0.1d, 400);
    }

    /**
     * Creates a perceptron.
     *
     * @param learningRate   the step size for each update, in the range (0, 1]
     * @param iterationLimit the maximum number of passes over the training data; at least
     *                       one pass is always made
     * @throws IllegalArgumentException if {@code learningRate} is outside (0, 1] or is NaN
     */
    public Perceptron(double learningRate, int iterationLimit) {
        // Written as a negated "inside" test so that NaN is rejected as well.
        if (!(learningRate > 0.0d && learningRate <= 1.0d)) {
            throw new IllegalArgumentException(
                    "Perceptron learning rate must be in the range (0,1], not " + learningRate);
        }
        this.learningRate = learningRate;
        this.iteratinLimit = iterationLimit;
    }

    /** Returns a hard assignment: probability 1 for the predicted class, 0 for the other. */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        categoricalResults.setProb(output(dataPoint), 1.0d);
        return categoricalResults;
    }

    /** Returns the raw decision value {@code weights·x + bias}. */
    @Override
    public double getScore(DataPoint dataPoint) {
        return this.weights.dot(dataPoint.getNumericalValues()) + this.bias;
    }

    /**
     * Trains with batch updates. Each pass splits the shuffled data into one slice per
     * available processor and submits each slice to an executor. The executor is obtained
     * from {@code ParallelUtils.getNewExecutor(parallel)}.
     *
     * @param classificationDataSet binary, purely numeric training data
     * @param parallel              forwarded to {@code ParallelUtils.getNewExecutor}
     * @throws FailedToFitException if the data is not binary/numeric, or a batch task fails
     *                              or is interrupted
     */
    @Override
    public void train(ClassificationDataSet classificationDataSet, boolean parallel) {
        validateDataSet(classificationDataSet);
        List<DataPointPair<Integer>> dataPoints = classificationDataSet.getAsDPPList();
        Collections.shuffle(dataPoints);
        int chunks = Runtime.getRuntime().availableProcessors();
        initModel(classificationDataSet.getNumNumericalVars());

        Vec bestWeights = null;
        double bestBias = this.bias;
        double lowestError = Double.MAX_VALUE;
        int iterations = 0;
        ExecutorService executor = ParallelUtils.getNewExecutor(parallel);
        try {
            do {
                List<Future<PairedReturn<Vec, Double[]>>> futures = new ArrayList<>(chunks);
                int chunkSize = dataPoints.size() / chunks;
                for (int c = 0; c < chunks; c++) {
                    int from = c * chunkSize;
                    // The last slice absorbs the remainder of the integer division.
                    int to = (c == chunks - 1) ? dataPoints.size() : from + chunkSize;
                    futures.add(executor.submit(new BatchTrainingUnit(dataPoints.subList(from, to))));
                }

                Vec weightChange = new DenseVector(this.weights.length());
                double biasChange = 0.0d;
                double globalError = 0.0d;
                for (Future<PairedReturn<Vec, Double[]>> future : futures) {
                    PairedReturn<Vec, Double[]> result = awaitResult(future);
                    weightChange.mutableAdd(result.getFirstItem());
                    biasChange += result.getSecondItem()[0];
                    globalError += result.getSecondItem()[1];
                }

                // The error was measured with the current weights and bias. Snapshot them by
                // copy, because the update below mutates this.weights in place.
                if (globalError < lowestError) {
                    bestWeights = this.weights.mo46clone();
                    bestBias = this.bias;
                    lowestError = globalError;
                }

                this.bias += biasChange;
                this.weights.mutableAdd(weightChange);
                iterations++;
                if (globalError <= 0.0d) {
                    break;
                }
            } while (iterations < this.iteratinLimit);
        } finally {
            executor.shutdownNow();
        }
        restoreBest(bestWeights, bestBias);
    }

    /** Trains with online (per-example) updates; equivalent to {@link #trainCOnline}. */
    @Override
    public void train(ClassificationDataSet classificationDataSet) {
        trainCOnline(classificationDataSet);
    }

    /**
     * Trains with online updates. The weights and bias are adjusted immediately after each
     * misclassified data point.
     * <p>
     * A pass's error is accumulated while that pass is updating the model. The state
     * recorded for the lowest-error pass is the weights and bias at the end of that pass.
     *
     * @param classificationDataSet binary, purely numeric training data
     * @throws FailedToFitException if the data is not binary or has categorical features
     */
    public void trainCOnline(ClassificationDataSet classificationDataSet) {
        validateDataSet(classificationDataSet);
        List<DataPointPair<Integer>> dataPoints = classificationDataSet.getAsDPPList();
        Collections.shuffle(dataPoints);
        initModel(classificationDataSet.getNumNumericalVars());

        Vec bestWeights = null;
        double bestBias = this.bias;
        double lowestError = Double.MAX_VALUE;
        int iterations = 0;
        do {
            double globalError = 0.0d;
            for (DataPointPair<Integer> dataPointPair : dataPoints) {
                DataPoint dataPoint = dataPointPair.getDataPoint();
                double localError = dataPointPair.getPair().intValue() - output(dataPoint);
                if (localError != 0.0d) {
                    double weight = dataPoint.getWeight();
                    double step = this.learningRate * localError * weight;
                    this.weights.mutableAdd(step, dataPointPair.getVector());
                    this.bias += step;
                    globalError += Math.abs(localError) * weight;
                }
            }
            // Copy the weights: later passes keep mutating this.weights in place.
            if (globalError < lowestError) {
                bestWeights = this.weights.mo46clone();
                bestBias = this.bias;
                lowestError = globalError;
            }
            iterations++;
            if (globalError <= 0.0d) {
                break;
            }
        } while (iterations < this.iteratinLimit);
        restoreBest(bestWeights, bestBias);
    }

    /** Rejects data sets this model cannot fit: non-binary targets or categorical features. */
    private static void validateDataSet(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Perceptron only supports binary classification");
        }
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Perceptron only supports vector classification");
        }
    }

    /**
     * Resets the model before a fit. Weights are drawn with {@code nextDouble()} from
     * {@code RandomUtil.getRandom()}, and the bias is set to zero.
     */
    private void initModel(int numFeatures) {
        Random random = RandomUtil.getRandom();
        this.weights = new DenseVector(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            this.weights.set(i, random.nextDouble());
        }
        this.bias = 0.0d;
    }

    /**
     * Installs the lowest-error weights and bias recorded during training. If no pass was
     * recorded (e.g. every pass produced a NaN error), the final state is kept.
     */
    private void restoreBest(Vec bestWeights, double bestBias) {
        if (bestWeights != null) {
            this.weights = bestWeights;
            this.bias = bestBias;
        }
    }

    /**
     * Waits for a batch task. Failures are rethrown as {@link FailedToFitException} instead
     * of silently dropping that task's contribution to the pass.
     */
    private static <T> T awaitResult(Future<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw fitFailure("Perceptron training was interrupted", e);
        } catch (ExecutionException e) {
            throw fitFailure("Perceptron batch training unit failed", e.getCause());
        }
    }

    /** Builds a FailedToFitException that carries {@code cause} as its cause. */
    private static FailedToFitException fitFailure(String message, Throwable cause) {
        FailedToFitException failure = new FailedToFitException(message);
        failure.initCause(cause);
        return failure;
    }

    /**
     * Returns the predicted class: 1 when {@link #getScore(DataPoint)} is non-negative,
     * otherwise 0.
     * <p>
     * NOTE(review): the decompiler reports this was originally private; it is kept public
     * so existing callers keep compiling.
     */
    public int output(DataPoint dataPoint) {
        return getScore(dataPoint) >= 0.0d ? 1 : 0;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /** Returns the live weight vector (not a copy); {@code null} before training. */
    @Override
    public Vec getRawWeight() {
        return this.weights;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the weight vector for {@code i}; only index 0 exists.
     *
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override
    public Vec getRawWeight(int i) {
        if (i != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getRawWeight();
    }

    /**
     * Returns the bias for {@code i}; only index 0 exists.
     *
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override
    public double getBias(int i) {
        if (i != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getBias();
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Returns an independent copy: same hyper-parameters and bias, and a cloned weight
     * vector (if trained).
     */
    @Override
    public Perceptron m73clone() {
        Perceptron perceptron = new Perceptron(this.learningRate, this.iteratinLimit);
        if (this.weights != null) {
            perceptron.weights = this.weights.mo46clone();
        }
        perceptron.bias = this.bias;
        return perceptron;
    }
}
