package jsat.classifiers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/classifiers/ClassificationDataSet.class */
/**
 * A data set for classification problems: every data point is paired with an
 * integer class label whose categories are described by {@link #getPredicting()}.
 * <p>
 * Data points and labels are stored in two index-aligned lists
 * ({@code datapoints} and {@code category}); the methods of this class always
 * add to both together.
 */
public class ClassificationDataSet extends DataSet<ClassificationDataSet> {
    /** Describes the target (class) variable. */
    protected CategoricalData predicting;
    /** The data points, index-aligned with {@code category}. */
    protected List<DataPoint> datapoints;
    /** The class label of each data point, index-aligned with {@code datapoints}. */
    protected IntList category;
    /** Shared empty categorical-value array used by the purely numeric {@code addDataPoint} overloads. */
    private static final int[] EMPTY_INT = new int[0];

    /**
     * Creates a classification data set from an existing data set by using one
     * of its categorical features as the class label. That feature is removed
     * from the categorical features of the new data set, and the numeric
     * variable names of {@code dataSet} are copied over.
     *
     * @param dataSet         the source data set
     * @param predictingIndex index of the categorical feature to use as the target
     */
    public ClassificationDataSet(DataSet<?> dataSet, int predictingIndex) {
        this(dataSet.getDataPoints(), predictingIndex);
        // generateGenericNumericNames() may leave the names unset (very wide
        // data sets), so create placeholder slots before copying names in.
        if (this.numericalVariableNames == null) {
            this.numericalVariableNames = new ArrayList<>();
            for (int j = 0; j < getNumNumericalVars(); j++) {
                this.numericalVariableNames.add("");
            }
        }
        for (int j = 0; j < getNumNumericalVars(); j++) {
            this.numericalVariableNames.set(j, dataSet.getNumericName(j));
        }
    }

    /**
     * Creates a classification data set from a list of data points by using one
     * of their categorical features as the class label.
     * <p>
     * The categorical feature descriptions are taken from the first data point.
     * Every point is re-created without the target feature, keeping its numeric
     * values and weight; the removed value becomes the point's class label.
     *
     * @param dataPoints      the source data points; must not be empty
     * @param predictingIndex index of the categorical feature to use as the target
     * @throws IndexOutOfBoundsException if {@code dataPoints} is empty
     */
    public ClassificationDataSet(List<DataPoint> dataPoints, int predictingIndex) {
        DataPoint first = dataPoints.get(0);
        CategoricalData[] sourceCategories = first.getCategoricalData();
        this.categories = new CategoricalData[first.numCategoricalValues() - 1];
        for (int j = 0; j < this.categories.length; j++) {
            // indices at or after the target shift down by one to skip it
            this.categories[j] = j >= predictingIndex ? sourceCategories[j + 1] : sourceCategories[j];
        }
        this.numNumerVals = first.numNumericalValues();
        this.predicting = sourceCategories[predictingIndex];
        this.datapoints = new ArrayList<>(dataPoints.size());
        this.category = new IntList(dataPoints.size());
        for (DataPoint dp : dataPoints) {
            int[] allValues = dp.getCategoricalValues();
            int[] featureValues = new int[dp.numCategoricalValues() - 1];
            int next = 0;
            for (int j = 0; j < allValues.length; j++) {
                if (j != predictingIndex) {
                    featureValues[next++] = allValues[j];
                }
            }
            this.datapoints.add(new DataPoint(dp.getNumericalValues(), featureValues, this.categories, dp.getWeight()));
            this.category.add(allValues[predictingIndex]);
        }
        generateGenericNumericNames();
    }

    /**
     * Creates a classification data set from data point / label pairs. The
     * categorical descriptions are copied from the first pair's data point; the
     * data points themselves are added as-is.
     *
     * @param pairs      the labeled data points; must not be empty
     * @param predicting description of the target variable
     * @throws IndexOutOfBoundsException if {@code pairs} is empty
     */
    public ClassificationDataSet(List<DataPointPair<Integer>> pairs, CategoricalData predicting) {
        this.predicting = predicting;
        this.numNumerVals = pairs.get(0).getVector().length();
        this.categories = CategoricalData.copyOf(pairs.get(0).getDataPoint().getCategoricalData());
        this.datapoints = new ArrayList<>(pairs.size());
        this.category = new IntList(pairs.size());
        for (DataPointPair<Integer> pair : pairs) {
            this.datapoints.add(pair.getDataPoint());
            this.category.add(pair.getPair());
        }
        generateGenericNumericNames();
    }

    /**
     * Creates an empty classification data set with the given layout.
     *
     * @param numerical  number of numeric features
     * @param categories descriptions of the categorical features (stored, not copied)
     * @param predicting description of the target variable
     */
    public ClassificationDataSet(int numerical, CategoricalData[] categories, CategoricalData predicting) {
        this.predicting = predicting;
        this.categories = categories;
        this.numNumerVals = numerical;
        this.datapoints = new ArrayList<>();
        this.category = new IntList();
        generateGenericNumericNames();
    }

    /**
     * Fills {@code numericalVariableNames} with "Numeric Input 1", "Numeric Input 2", ...
     * Does nothing (leaving the field as it was) when there are more than 100
     * numeric features.
     */
    private void generateGenericNumericNames() {
        if (getNumNumericalVars() > 100) {
            return;
        }
        this.numericalVariableNames = new ArrayList<>(getNumNumericalVars());
        for (int j = 0; j < getNumNumericalVars(); j++) {
            this.numericalVariableNames.add("Numeric Input " + (j + 1));
        }
    }

    /**
     * @return the number of target classes
     */
    public int getClassSize() {
        return this.predicting.getNumOfCategories();
    }

    /**
     * Merges every data set in {@code list} except the one at {@code exclude}
     * into a new data set. (Sic: "comine" — the name is part of the public API.)
     * <p>
     * The result takes its feature layout and its target description (the same
     * object, not a clone) from the excluded data set. Data points are shared
     * with the source data sets, not copied.
     *
     * @param list    the data sets to merge
     * @param exclude index of the data set to leave out
     * @return the merged data set
     */
    public static ClassificationDataSet comineAllBut(List<ClassificationDataSet> list, int exclude) {
        ClassificationDataSet template = list.get(exclude);
        ClassificationDataSet merged = new ClassificationDataSet(
                template.getNumNumericalVars(), template.getCategories(), template.getPredicting());
        for (int j = 0; j < list.size(); j++) {
            if (j != exclude) {
                merged.datapoints.addAll(list.get(j).datapoints);
                merged.category.addAll(list.get(j).category);
            }
        }
        return merged;
    }

    @Override // jsat.DataSet
    public DataPoint getDataPoint(int i) {
        return getDataPointPair(i).getDataPoint();
    }

    /**
     * Returns the data point at index {@code i} together with its class label.
     *
     * @param i the sample index
     * @return a new pair holding the data point and its label
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less than the sample size
     */
    public DataPointPair<Integer> getDataPointPair(int i) {
        if (i >= getSampleSize()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set");
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        return new DataPointPair<>(this.datapoints.get(i), this.category.get(i));
    }

    /**
     * Replaces the data point at index {@code i}, keeping its class label, and
     * clears the column-vector cache. The new point's shape is not validated.
     *
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less than the sample size
     */
    @Override // jsat.DataSet
    public void setDataPoint(int i, DataPoint dataPoint) {
        if (i >= getSampleSize()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set");
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        this.datapoints.set(i, dataPoint);
        this.columnVecCache.clear();
    }

    /**
     * Returns the class label of the data point at index {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less than the sample size
     */
    public int getDataPointCategory(int i) {
        if (i >= getSampleSize()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set: " + i);
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        return this.category.getI(i);
    }

    // javac emits the erased bridge method getSubset(List) for this override on
    // its own; declaring it by hand as well is a duplicate-erasure compile error.
    @Override // jsat.DataSet
    protected ClassificationDataSet getSubset(List<Integer> indices) {
        ClassificationDataSet subset = new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting);
        for (int idx : indices) {
            subset.addDataPoint(getDataPoint(idx), getDataPointCategory(idx));
        }
        return subset;
    }

    /**
     * Splits this data set into {@code folds} stratified folds. The samples of
     * each class are shuffled with {@code random} and dealt round-robin across
     * the folds; the fold pointer carries over from one class to the next.
     * Each fold gets its own clone of the target description. Samples whose
     * label is outside {@code [0, getClassSize())} are not placed in any fold.
     *
     * @param folds  number of folds; must be positive, otherwise an
     *               {@link IndexOutOfBoundsException} is thrown as soon as a
     *               sample has to be placed
     * @param random source of randomness for the shuffles
     * @return the list of folds
     */
    public List<ClassificationDataSet> stratSet(int folds, Random random) {
        List<ClassificationDataSet> result = new ArrayList<>();
        while (result.size() < folds) {
            result.add(new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting.m1clone()));
        }
        IntList order = new IntList();
        int fold = 0;
        for (int c = 0; c < getClassSize(); c++) {
            List<DataPoint> samples = getSamples(c);
            order.clear();
            ListUtils.addRange(order, 0, samples.size(), 1);
            Collections.shuffle(order, random);
            for (int idx : order) {
                result.get(fold).addDataPoint(samples.get(idx), c);
                fold = (fold + 1) % folds;
            }
        }
        return result;
    }

    /** Adds a data point with weight 1.0. See {@link #addDataPoint(Vec, int[], int, double)}. */
    public void addDataPoint(Vec vec, int[] iArr, int i) {
        addDataPoint(vec, iArr, i, 1.0d);
    }

    /** Adds a purely numeric data point with weight 1.0. */
    public void addDataPoint(Vec vec, int i) {
        addDataPoint(vec, EMPTY_INT, i, 1.0d);
    }

    /** Adds a purely numeric data point with the given weight. */
    public void addDataPoint(Vec vec, int i, double d) {
        addDataPoint(vec, EMPTY_INT, i, d);
    }

    /**
     * Adds a new data point built from the given values and clears the
     * column-vector cache. The class label itself is not range-checked.
     *
     * @param vec    numeric values; length must equal the number of numeric features
     * @param iArr   categorical values; length must equal the number of categorical features
     * @param i      class label
     * @param d      weight of the data point
     * @throws IllegalArgumentException if the values do not match this data set's layout
     */
    public void addDataPoint(Vec vec, int[] iArr, int i, double d) {
        checkLayout(vec, iArr);
        this.datapoints.add(new DataPoint(vec, iArr, this.categories, d));
        this.category.add(i);
        this.columnVecCache.clear();
    }

    /**
     * Adds an existing data point (stored as-is, not copied) with the given
     * class label and clears the column-vector cache.
     *
     * @throws IllegalArgumentException if the point does not match this data set's layout
     */
    public void addDataPoint(DataPoint dataPoint, int i) {
        checkLayout(dataPoint.getNumericalValues(), dataPoint.getCategoricalValues());
        this.datapoints.add(dataPoint);
        this.category.add(i);
        this.columnVecCache.clear();
    }

    /**
     * Shared validation for the {@code addDataPoint} overloads: checks the
     * numeric and categorical lengths and each categorical value. Negative
     * categorical values are accepted (presumably missing-value markers).
     */
    private void checkLayout(Vec numeric, int[] categorical) {
        if (numeric.length() != this.numNumerVals) {
            throw new IllegalArgumentException("Data point does not contain enough numerical data points");
        }
        if (categorical.length != this.categories.length) {
            throw new IllegalArgumentException("Data point does not contain enough categorical data points");
        }
        for (int j = 0; j < categorical.length; j++) {
            if (!this.categories[j].isValidCategory(categorical[j]) && categorical[j] >= 0) {
                throw new IllegalArgumentException("Categoriy value given is invalid");
            }
        }
    }

    /**
     * Returns the data points whose label is {@code i}, in data set order.
     * The returned list is new; the data points are shared.
     */
    public List<DataPoint> getSamples(int i) {
        List<DataPoint> matches = new ArrayList<>();
        for (int j = 0; j < this.category.size(); j++) {
            if (this.category.getI(j) == i) {
                matches.add(this.datapoints.get(j));
            }
        }
        return matches;
    }

    /**
     * Returns the values of numeric feature {@code i2} for every sample of class
     * {@code i}, in data set order.
     */
    public Vec getSampleVariableVector(int i, int i2) {
        List<DataPoint> samples = getSamples(i);
        DenseVector values = new DenseVector(samples.size());
        for (int j = 0; j < values.length(); j++) {
            values.set(j, samples.get(j).getNumericalValues().get(i2));
        }
        return values;
    }

    /**
     * @return the description of the target variable (the stored object, not a copy)
     */
    public CategoricalData getPredicting() {
        return this.predicting;
    }

    /**
     * @return a new list pairing each data point with its class label
     */
    public List<DataPointPair<Integer>> getAsDPPList() {
        List<DataPointPair<Integer>> pairs = new ArrayList<>(getSampleSize());
        for (int j = 0; j < getSampleSize(); j++) {
            pairs.add(new DataPointPair<>(this.datapoints.get(j), this.category.get(j)));
        }
        return pairs;
    }

    /**
     * @return a new list pairing each data point with its class label converted to a double
     */
    public List<DataPointPair<Double>> getAsFloatDPPList() {
        List<DataPointPair<Double>> pairs = new ArrayList<>(getSampleSize());
        for (int j = 0; j < getSampleSize(); j++) {
            pairs.add(new DataPointPair<>(this.datapoints.get(j), Double.valueOf(this.category.getI(j))));
        }
        return pairs;
    }

    /**
     * Computes the weighted prior probability of each class: the total weight of
     * its samples divided by the total weight of all samples. If the total weight
     * is zero (e.g. an empty data set) every entry is NaN.
     *
     * @return an array of length {@link #getClassSize()}
     * @throws ArrayIndexOutOfBoundsException if a stored label is outside {@code [0, getClassSize())}
     */
    public double[] getPriors() {
        double[] priors = new double[getClassSize()];
        double totalWeight = 0.0d;
        for (int j = 0; j < getSampleSize(); j++) {
            double weight = this.datapoints.get(j).getWeight();
            priors[this.category.getI(j)] += weight;
            totalWeight += weight;
        }
        for (int c = 0; c < priors.length; c++) {
            priors[c] /= totalWeight;
        }
        return priors;
    }

    /**
     * @return the number of samples whose label is {@code i} (unweighted)
     */
    public int classSampleCount(int i) {
        int count = 0;
        for (int j = 0; j < this.category.size(); j++) {
            if (this.category.getI(j) == i) {
                count++;
            }
        }
        return count;
    }

    @Override // jsat.DataSet
    public int getSampleSize() {
        return this.datapoints.size();
    }

    /**
     * Returns a new data set sharing the same data points and labels (new lists),
     * a copy of the column-vector cache, and a clone of the target description.
     */
    // NOTE(review): decompiler-renamed from shallowClone (merged with its bridge
    // method); confirm the override target in DataSet before renaming anything.
    @Override // jsat.DataSet
    public DataSet<ClassificationDataSet> shallowClone2() {
        ClassificationDataSet copy = new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting.m1clone());
        copy.datapoints.addAll(this.datapoints);
        copy.category.addAll(this.category);
        copy.columnVecCache.putAll(this.columnVecCache);
        return copy;
    }

    @Override // jsat.DataSet
    public ClassificationDataSet getTwiceShallowClone() {
        return (ClassificationDataSet) super.getTwiceShallowClone();
    }
}
