package jsat.datatransform;

import java.util.List;
import java.util.Objects;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/* loaded from: input_file:jsat/datatransform/DataModelPipeline.class */
public class DataModelPipeline implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = -2300996837897094414L;

    @Parameter.ParameterHolder(skipSelfNamePrefix = true)
    private DataTransformProcess baseDtp;
    private Classifier baseClassifier;
    private Regressor baseRegressor;
    private DataTransformProcess learnedDtp;
    private Classifier learnedClassifier;
    private Regressor learnedRegressor;

    public DataModelPipeline(Classifier classifier, DataTransformProcess dataTransformProcess) {
        this.baseDtp = dataTransformProcess;
        this.baseClassifier = classifier;
        if (classifier instanceof Regressor) {
            this.baseRegressor = (Regressor) classifier;
        }
    }

    public DataModelPipeline(Classifier classifier, DataTransform... dataTransformArr) {
        this(classifier, new DataTransformProcess(dataTransformArr));
    }

    public DataModelPipeline(Regressor regressor, DataTransformProcess dataTransformProcess) {
        this.baseDtp = dataTransformProcess;
        this.baseRegressor = regressor;
        if (regressor instanceof Classifier) {
            this.baseClassifier = (Classifier) regressor;
        }
    }

    public DataModelPipeline(Regressor regressor, DataTransform... dataTransformArr) {
        this(regressor, new DataTransformProcess(dataTransformArr));
    }

    public DataModelPipeline(DataModelPipeline dataModelPipeline) {
        this.baseDtp = dataModelPipeline.baseDtp.clone();
        if (dataModelPipeline.baseClassifier != null && dataModelPipeline.baseClassifier == dataModelPipeline.baseRegressor) {
            this.baseClassifier = dataModelPipeline.baseClassifier.clone();
            this.baseRegressor = (Regressor) this.baseClassifier;
        } else if (dataModelPipeline.baseClassifier != null) {
            this.baseClassifier = dataModelPipeline.baseClassifier.clone();
        } else {
            if (dataModelPipeline.baseRegressor == null) {
                throw new RuntimeException("BUG: Report Me!");
            }
            this.baseRegressor = dataModelPipeline.baseRegressor.clone();
        }
        if (dataModelPipeline.learnedDtp != null) {
            this.learnedDtp = dataModelPipeline.learnedDtp.clone();
        }
        if (dataModelPipeline.learnedClassifier != null) {
            this.learnedClassifier = dataModelPipeline.learnedClassifier.clone();
        }
        if (dataModelPipeline.learnedRegressor != null) {
            this.learnedRegressor = dataModelPipeline.learnedRegressor.clone();
        }
    }

    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        return this.learnedClassifier.classify(this.learnedDtp.transform(dataPoint));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [jsat.DataSet, jsat.classifiers.ClassificationDataSet] */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        this.learnedDtp = this.baseDtp.clone();
        ?? shallowClone2 = classificationDataSet.shallowClone2();
        this.learnedDtp.learnApplyTransforms(shallowClone2);
        this.learnedClassifier = this.baseClassifier.clone();
        this.learnedClassifier.train(shallowClone2, z);
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        if (this.baseClassifier != null) {
            return this.baseClassifier.supportsWeightedData();
        }
        if (this.baseRegressor != null) {
            return this.baseRegressor.supportsWeightedData();
        }
        throw new RuntimeException("BUG: Report Me! This should not have happened");
    }

    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return this.learnedRegressor.regress(this.learnedDtp.transform(dataPoint));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [jsat.DataSet, jsat.regression.RegressionDataSet] */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, boolean z) {
        this.learnedDtp = this.baseDtp.clone();
        ?? shallowClone2 = regressionDataSet.shallowClone2();
        this.learnedDtp.learnApplyTransforms(shallowClone2);
        this.learnedRegressor = this.baseRegressor.clone();
        this.learnedRegressor.train(shallowClone2, z);
    }

    @Override // jsat.regression.Regressor
    public DataModelPipeline clone() {
        return new DataModelPipeline(this);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        List<Parameter> paramsFromMethods = Parameter.getParamsFromMethods(this);
        if (this.baseClassifier != null && (this.baseClassifier instanceof Parameterized)) {
            paramsFromMethods.addAll(((Parameterized) this.baseClassifier).getParameters());
        } else if (this.baseRegressor != null && (this.baseRegressor instanceof Parameterized)) {
            paramsFromMethods.addAll(((Parameterized) this.baseRegressor).getParameters());
        }
        return paramsFromMethods;
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
