/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformProcess;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * A model that chains a {@link DataTransformProcess} with a base
 * {@link Classifier} and/or {@link Regressor}.
 * <p>
 * Training works on clones of the base objects:
 * <ol>
 * <li>The base transform process is cloned, then learned on and applied to a
 * shallow clone of the training set.</li>
 * <li>A clone of the base model is trained on that transformed data.</li>
 * </ol>
 * Prediction passes each input through the learned transform process and then
 * delegates to the learned model. The base objects given at construction are
 * never trained directly, so they act as templates that can be reused across
 * calls to {@code train}.
 * <p>
 * The learned transform process is shared by the classification and
 * regression sides of this object. Training one side therefore discards any
 * model previously learned for the other side, because that model was fit to a
 * different transform.
 */
public class DataModelPipeline
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = -2300996837897094414L;
    /**
     * Template transform process, cloned on every call to {@code train}.
     * Annotated so its parameters are exposed through this pipeline's own
     * parameters (presumably collected by {@link Parameter#getParamsFromMethods}
     * — confirm).
     */
    @Parameter.ParameterHolder(skipSelfNamePrefix=true)
    private DataTransformProcess baseDtp;
    /** Template classifier; {@code null} if the base model is not a Classifier. */
    private Classifier baseClassifier;
    /** Template regressor; {@code null} if the base model is not a Regressor. */
    private Regressor baseRegressor;
    /** Transform process fit by the most recent successful {@code train}, or {@code null}. */
    private DataTransformProcess learnedDtp;
    /** Classifier fit to {@link #learnedDtp}, or {@code null} if not trained for classification. */
    private Classifier learnedClassifier;
    /** Regressor fit to {@link #learnedDtp}, or {@code null} if not trained for regression. */
    private Regressor learnedRegressor;

    /**
     * Creates a pipeline around a classifier.
     *
     * @param baseClassifier the classifier to train on transformed data; if it is
     *                       also a {@link Regressor}, the pipeline supports
     *                       regression as well
     * @param dtp            the transform process to learn before training
     */
    public DataModelPipeline(Classifier baseClassifier, DataTransformProcess dtp) {
        this.baseDtp = dtp;
        this.baseClassifier = baseClassifier;
        if (baseClassifier instanceof Regressor) {
            this.baseRegressor = (Regressor)((Object)baseClassifier);
        }
    }

    /**
     * Creates a pipeline around a classifier from individual transforms.
     *
     * @param baseClassifier the classifier to train on transformed data
     * @param transforms     the transforms, wrapped into a new {@link DataTransformProcess}
     */
    public DataModelPipeline(Classifier baseClassifier, DataTransform ... transforms) {
        this(baseClassifier, new DataTransformProcess(transforms));
    }

    /**
     * Creates a pipeline around a regressor.
     *
     * @param baseRegressor the regressor to train on transformed data; if it is
     *                      also a {@link Classifier}, the pipeline supports
     *                      classification as well
     * @param dtp           the transform process to learn before training
     */
    public DataModelPipeline(Regressor baseRegressor, DataTransformProcess dtp) {
        this.baseDtp = dtp;
        this.baseRegressor = baseRegressor;
        if (baseRegressor instanceof Classifier) {
            this.baseClassifier = (Classifier)((Object)baseRegressor);
        }
    }

    /**
     * Creates a pipeline around a regressor from individual transforms.
     *
     * @param baseRegressor the regressor to train on transformed data
     * @param transforms    the transforms, wrapped into a new {@link DataTransformProcess}
     */
    public DataModelPipeline(Regressor baseRegressor, DataTransform ... transforms) {
        this(baseRegressor, new DataTransformProcess(transforms));
    }

    /**
     * Copy constructor: clones the base objects and any learned state of
     * {@code toCopy}.
     *
     * @param toCopy the pipeline to copy
     * @throws RuntimeException if {@code toCopy} has neither a base classifier
     *                          nor a base regressor
     */
    public DataModelPipeline(DataModelPipeline toCopy) {
        this.baseDtp = toCopy.baseDtp.clone();
        if (toCopy.baseClassifier != null && (Object)toCopy.baseClassifier == toCopy.baseRegressor) {
            // One object fills both roles: clone it once so the copy preserves that sharing.
            this.baseClassifier = toCopy.baseClassifier.clone();
            this.baseRegressor = (Regressor)((Object)this.baseClassifier);
        } else if (toCopy.baseClassifier != null) {
            this.baseClassifier = toCopy.baseClassifier.clone();
        } else if (toCopy.baseRegressor != null) {
            this.baseRegressor = toCopy.baseRegressor.clone();
        } else {
            throw new RuntimeException("BUG: Report Me!");
        }
        if (toCopy.learnedDtp != null) {
            this.learnedDtp = toCopy.learnedDtp.clone();
        }
        if (toCopy.learnedClassifier != null) {
            this.learnedClassifier = toCopy.learnedClassifier.clone();
        }
        if (toCopy.learnedRegressor != null) {
            this.learnedRegressor = toCopy.learnedRegressor.clone();
        }
    }

    /**
     * Transforms {@code data} with the learned transform process and classifies
     * the result with the learned classifier.
     *
     * @param data the raw (untransformed) data point
     * @return the learned classifier's results for the transformed point
     * @throws IllegalStateException if the pipeline has not been trained as a classifier
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.learnedClassifier == null || this.learnedDtp == null) {
            throw new IllegalStateException("The pipeline has not been trained as a classifier");
        }
        return this.learnedClassifier.classify(this.learnedDtp.transform(data));
    }

    /**
     * Learns the transform process on a shallow clone of {@code dataSet}, then
     * trains a clone of the base classifier on the transformed data.
     * <p>
     * Learned state is replaced only if training succeeds. Any previously
     * learned regressor is discarded, since it was fit to a different transform.
     *
     * @param dataSet  the training data (the caller's data set is not modified
     *                 by this method itself)
     * @param parallel passed through to the base classifier's training
     * @throws UnsupportedOperationException if the base model is not a Classifier
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        // Fail before doing the (possibly expensive) transform learning.
        if (this.baseClassifier == null) {
            throw new UnsupportedOperationException("The base model of this pipeline is not a Classifier");
        }
        DataTransformProcess dtp = this.baseDtp.clone();
        ClassificationDataSet transformed = dataSet.shallowClone();
        dtp.learnApplyTransforms(transformed);
        Classifier model = this.baseClassifier.clone();
        model.train(transformed, parallel);
        // Publish only after everything succeeded, so a failure leaves prior state intact.
        this.learnedDtp = dtp;
        this.learnedClassifier = model;
        this.learnedRegressor = null;
    }

    /**
     * Reports whether the base model supports weighted data. The classifier is
     * consulted first, then the regressor.
     *
     * @return the base model's {@code supportsWeightedData()} answer
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.baseClassifier != null) {
            return this.baseClassifier.supportsWeightedData();
        }
        if (this.baseRegressor != null) {
            return this.baseRegressor.supportsWeightedData();
        }
        throw new RuntimeException("BUG: Report Me! This should not have happened");
    }

    /**
     * Transforms {@code data} with the learned transform process and regresses
     * the result with the learned regressor.
     *
     * @param data the raw (untransformed) data point
     * @return the learned regressor's prediction for the transformed point
     * @throws IllegalStateException if the pipeline has not been trained as a regressor
     */
    @Override
    public double regress(DataPoint data) {
        if (this.learnedRegressor == null || this.learnedDtp == null) {
            throw new IllegalStateException("The pipeline has not been trained as a regressor");
        }
        return this.learnedRegressor.regress(this.learnedDtp.transform(data));
    }

    /**
     * Learns the transform process on a shallow clone of {@code dataSet}, then
     * trains a clone of the base regressor on the transformed data.
     * <p>
     * Learned state is replaced only if training succeeds. Any previously
     * learned classifier is discarded, since it was fit to a different transform.
     *
     * @param dataSet  the training data (the caller's data set is not modified
     *                 by this method itself)
     * @param parallel passed through to the base regressor's training
     * @throws UnsupportedOperationException if the base model is not a Regressor
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        // Fail before doing the (possibly expensive) transform learning.
        if (this.baseRegressor == null) {
            throw new UnsupportedOperationException("The base model of this pipeline is not a Regressor");
        }
        DataTransformProcess dtp = this.baseDtp.clone();
        RegressionDataSet transformed = dataSet.shallowClone();
        dtp.learnApplyTransforms(transformed);
        Regressor model = this.baseRegressor.clone();
        model.train(transformed, parallel);
        // Publish only after everything succeeded, so a failure leaves prior state intact.
        this.learnedDtp = dtp;
        this.learnedRegressor = model;
        this.learnedClassifier = null;
    }

    /**
     * Returns a copy of this pipeline, including any learned state.
     *
     * @return a new pipeline built with the copy constructor
     */
    @Override
    public DataModelPipeline clone() {
        return new DataModelPipeline(this);
    }

    /**
     * Returns this pipeline's own parameters plus, if the base model is
     * {@link Parameterized}, the base model's parameters.
     *
     * @return a list of all exposed parameters
     */
    @Override
    public List<Parameter> getParameters() {
        List<Parameter> params = Parameter.getParamsFromMethods(this);
        // instanceof is false for null, so no separate null checks are needed.
        if (this.baseClassifier instanceof Parameterized) {
            params.addAll(((Parameterized)((Object)this.baseClassifier)).getParameters());
        } else if (this.baseRegressor instanceof Parameterized) {
            params.addAll(((Parameterized)((Object)this.baseRegressor)).getParameters());
        }
        return params;
    }

    /**
     * Looks up a parameter by name among {@link #getParameters()}.
     *
     * @param paramName the parameter name
     * @return the matching parameter, or whatever the parameter map yields for
     *         an unknown name (presumably {@code null} — confirm)
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

