/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.UntrainedModelException;
import jsat.math.Function1D;
import jsat.math.rootfinding.Zeroin;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.DoubleList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Regressor that builds an additive ensemble of learners.
 *
 * <p>Training first fits an initial learner (a clone of the strong learner
 * when one was supplied, otherwise a clone of the weak learner) on the whole
 * data set. It then runs {@code maxIterations} rounds; each round recomputes
 * the residuals of the ensemble built so far, fits a fresh clone of the weak
 * learner on a random sample of those residuals, and appends it with a
 * coefficient equal to {@code learningRate} times a multiplier found by a
 * one-dimensional root search.
 *
 * <p>A prediction is the coefficient-weighted sum of the outputs of all
 * learners in the ensemble.
 */
public class StochasticGradientBoosting implements Regressor, Parameterized {
    private static final long serialVersionUID = -2855154397476855293L;
    /** Sampling proportion used by the constructors that do not take one. */
    public static final double DEFAULT_TRAINING_PROPORTION = 0.5;
    /** Learning rate used by the constructor that does not take one. */
    public static final double DEFAULT_LEARNING_RATE = 0.1;
    /** Fraction of the residual set handed to each weak learner. */
    private double trainingProportion;
    /** Prototype that is cloned for every boosting round. */
    private Regressor weakLearner;
    /** Optional prototype for the initial learner; {@code null} means "use the weak learner". */
    private Regressor strongLearner;
    /** Trained learners, in the order they were added; {@code null} until trained. */
    private List<Regressor> F;
    /** Coefficient of each learner in {@link #F}, index-aligned with it. */
    private List<Double> coef;
    /** Factor applied to every multiplier before it is stored in {@link #coef}. */
    private double learningRate;
    /** Number of boosting rounds performed after the initial learner. */
    private int maxIterations;

    /**
     * Creates an untrained booster.
     *
     * <p>The arguments are stored as given. Unlike {@link #setLearningRate(double)}
     * and {@link #setTrainingProportion(double)}, this constructor performs no
     * range checks.
     *
     * @param strongLearner prototype for the initial learner, or {@code null}
     *        to use a clone of {@code weakLearner} instead
     * @param weakLearner prototype cloned for every boosting round
     * @param maxIterations number of boosting rounds
     * @param learningRate factor applied to each learner's multiplier
     * @param trainingPortion fraction of the data sampled for each round
     */
    public StochasticGradientBoosting(Regressor strongLearner, Regressor weakLearner, int maxIterations, double learningRate, double trainingPortion) {
        this.trainingProportion = trainingPortion;
        this.strongLearner = strongLearner;
        this.weakLearner = weakLearner;
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
    }

    /**
     * Creates an untrained booster without a strong initial learner.
     *
     * @param weakLearner prototype cloned for the initial learner and every round
     * @param maxIterations number of boosting rounds
     * @param learningRate factor applied to each learner's multiplier
     * @param trainingPortion fraction of the data sampled for each round
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate, double trainingPortion) {
        this(null, weakLearner, maxIterations, learningRate, trainingPortion);
    }

    /**
     * Creates an untrained booster using {@link #DEFAULT_TRAINING_PROPORTION}.
     *
     * @param weakLearner prototype cloned for the initial learner and every round
     * @param maxIterations number of boosting rounds
     * @param learningRate factor applied to each learner's multiplier
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate) {
        this(weakLearner, maxIterations, learningRate, DEFAULT_TRAINING_PROPORTION);
    }

    /**
     * Creates an untrained booster using {@link #DEFAULT_LEARNING_RATE} and
     * {@link #DEFAULT_TRAINING_PROPORTION}.
     *
     * @param weakLearner prototype cloned for the initial learner and every round
     * @param maxIterations number of boosting rounds
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations) {
        this(weakLearner, maxIterations, DEFAULT_LEARNING_RATE);
    }

    /**
     * Sets the number of boosting rounds used by the next call to
     * {@link #train(RegressionDataSet, boolean)}. The value is not validated.
     *
     * @param maxIterations number of boosting rounds
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * @return the configured number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate a value in the range (0, 1]
     * @throws ArithmeticException if the value is NaN or outside (0, 1]
     */
    public void setLearningRate(double learningRate) {
        if (learningRate > 1.0 || learningRate <= 0.0 || Double.isNaN(learningRate)) {
            throw new ArithmeticException("Invalid learning rate");
        }
        this.learningRate = learningRate;
    }

    /**
     * @return the configured learning rate
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the fraction of the data sampled for each boosting round.
     *
     * @param trainingProportion a value in the range (0, 1]
     * @throws ArithmeticException if the value is NaN or outside (0, 1]
     */
    public void setTrainingProportion(double trainingProportion) {
        if (trainingProportion > 1.0 || trainingProportion <= 0.0 || Double.isNaN(trainingProportion)) {
            throw new ArithmeticException("Training Proportion is invalid");
        }
        this.trainingProportion = trainingProportion;
    }

    /**
     * @return the configured sampling fraction
     */
    public double getTrainingProportion() {
        return this.trainingProportion;
    }

    /**
     * Predicts a value as the coefficient-weighted sum of every learner's output.
     *
     * @param data the point to predict for
     * @return the ensemble prediction
     * @throws UntrainedModelException if no learners have been trained yet
     */
    @Override
    public double regress(DataPoint data) {
        if (this.F == null || this.F.isEmpty()) {
            throw new UntrainedModelException();
        }
        double result = 0.0;
        for (int i = 0; i < this.F.size(); ++i) {
            result += this.F.get(i).regress(data) * this.coef.get(i);
        }
        return result;
    }

    /**
     * Trains the ensemble on the given data set.
     *
     * <p>The ensemble is assembled in local lists and only stored in this
     * object once every round has completed, so an exception thrown part-way
     * through leaves any previously trained model untouched.
     *
     * @param dataSet the training data; its target values are read but the
     *        residuals are written to the list obtained from {@code getAsDPPList()}
     * @param parallel forwarded unchanged to every learner's {@code train} call
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        List<DataPointPair<Double>> backingResidsList = dataSet.getAsDPPList();
        List<Regressor> learners = new ArrayList<Regressor>(this.maxIterations);
        List<Double> weights = new DoubleList(this.maxIterations);

        // Initial learner: the strong learner when one was supplied, else a weak learner.
        Regressor lastF = this.strongLearner == null ? this.weakLearner.clone() : this.strongLearner.clone();
        lastF.train(dataSet, parallel);
        learners.add(lastF);
        weights.add(this.learningRate * this.getMinimizingErrorConst(backingResidsList, lastF));

        // Running ensemble prediction per sample, so earlier learners are never re-evaluated.
        double[] currPredictions = new double[dataSet.getSampleSize()];
        // resids is built on top of backingResidsList; the residual targets written
        // through it are what the sampling and root search below read from that list.
        RegressionDataSet resids = RegressionDataSet.usingDPPList(backingResidsList);
        int randSampleSize = (int)Math.round((double)resids.getSampleSize() * this.trainingProportion);
        ArrayList<DataPointPair<Double>> randSampleList = new ArrayList<DataPointPair<Double>>(randSampleSize);
        Random rand = RandomUtil.getRandom();

        for (int iter = 0; iter < this.maxIterations; ++iter) {
            double lastCoef = weights.get(iter);
            lastF = learners.get(iter);
            // Fold the most recent learner into the running predictions and
            // recompute each residual from the original target.
            for (int j = 0; j < resids.getSampleSize(); ++j) {
                double lastFPred = lastF.regress(resids.getDataPoint(j));
                currPredictions[j] += lastCoef * lastFPred;
                resids.setTargetValue(j, dataSet.getTargetValue(j) - currPredictions[j]);
            }

            // Fit a fresh weak learner on a random sample of the residuals.
            randSampleList.clear();
            ListUtils.randomSample(backingResidsList, randSampleList, randSampleSize, rand);
            Regressor h = this.weakLearner.clone();
            RegressionDataSet tmpDataSet = RegressionDataSet.usingDPPList(randSampleList);
            h.train(tmpDataSet, parallel);

            // The multiplier is searched against the full residual list, not just the sample.
            double y = this.getMinimizingErrorConst(backingResidsList, h);
            learners.add(h);
            weights.add(this.learningRate * y);
        }

        // Publish only a fully built ensemble.
        this.F = learners;
        this.coef = weights;
    }

    /**
     * Searches for the multiplier of {@code h} at which the derivative estimate
     * from {@link #getDerivativeFunc(List, Regressor)} is zero.
     *
     * @param backingResidsList data points paired with their current target values
     * @param h the learner whose multiplier is sought
     * @return the value returned by the Zeroin root finder, called with
     *         arguments 1e-4, 50 and the interval [-2.5, 2.5]
     */
    private double getMinimizingErrorConst(List<DataPointPair<Double>> backingResidsList, Regressor h) {
        Function1D fhPrime = this.getDerivativeFunc(backingResidsList, h);
        Zeroin rf = new Zeroin();
        return rf.root(1.0E-4, 50, new double[]{-2.5, 2.5}, fhPrime);
    }

    /**
     * Builds a function of the multiplier {@code c} that returns a backward
     * finite-difference estimate of the derivative of the squared error
     * {@code sum_i (c * h(x_i) - t_i)^2}.
     *
     * <p>With {@code eps = 1e-5} the returned value is
     * {@code sum_i (c*h_i - t_i)^2 - sum_i ((c - eps)*h_i - t_i)^2}, which
     * simplifies to {@code eps * sum_i h_i * ((2c - eps) * h_i - 2 * t_i)};
     * the simplified form is what is evaluated.
     *
     * <p>The predictions {@code h(x_i)} are computed once, when this method is
     * called, rather than on every evaluation of the returned function. The
     * target values are still read from the list on each evaluation. The
     * function therefore assumes {@code h} and the list's size and order do
     * not change while it is in use.
     *
     * @param backingResidsList data points paired with their target values
     * @param h the learner whose multiplier is being searched
     * @return the derivative-estimate function
     */
    private Function1D getDerivativeFunc(List<DataPointPair<Double>> backingResidsList, Regressor h) {
        // The root finder evaluates the function repeatedly; h's output for a
        // given point does not depend on the multiplier, so cache it up front.
        final double[] hEsts = new double[backingResidsList.size()];
        int fill = 0;
        for (DataPointPair<Double> dpp : backingResidsList) {
            hEsts[fill++] = h.regress(dpp.getDataPoint());
        }
        return x -> {
            double c1 = x;
            double eps = 1.0E-5;
            // c1 + c2 with c2 = c1 - eps
            double c1Pc2 = c1 * 2.0 - eps;
            double result = 0.0;
            int i = 0;
            for (DataPointPair<Double> dpp : backingResidsList) {
                double hEst = hEsts[i++];
                double target = dpp.getPair();
                result += hEst * (c1Pc2 * hEst - 2.0 * target);
            }
            // eps is a common factor of every term, so it is applied once at the end.
            return result * eps;
        };
    }

    /**
     * @return {@code true} only if the weak learner, and the strong learner
     *         when one is set, both support weighted data
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.strongLearner != null) {
            return this.strongLearner.supportsWeightedData() && this.weakLearner.supportsWeightedData();
        }
        return this.weakLearner.supportsWeightedData();
    }

    /**
     * Creates a copy of this booster. The weak learner, the strong learner
     * (if set) and every trained learner are cloned, and the coefficient list
     * is copied into a new list.
     *
     * @return the copy
     */
    @Override
    public StochasticGradientBoosting clone() {
        StochasticGradientBoosting clone = new StochasticGradientBoosting(this.weakLearner.clone(), this.maxIterations, this.learningRate, this.trainingProportion);
        if (this.F != null) {
            clone.F = new ArrayList<Regressor>(this.F.size());
            for (Regressor f : this.F) {
                clone.F.add(f.clone());
            }
        }
        if (this.coef != null) {
            clone.coef = new DoubleList(this.coef);
        }
        if (this.strongLearner != null) {
            clone.strongLearner = this.strongLearner.clone();
        }
        return clone;
    }
}

