/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.regression;

import org.jdmp.core.algorithm.AbstractAlgorithm;
import org.jdmp.core.algorithm.Algorithm;
import org.jdmp.core.algorithm.basic.FMeasure;
import org.jdmp.core.algorithm.basic.Minus;
import org.jdmp.core.algorithm.basic.Precision;
import org.jdmp.core.algorithm.basic.Recall;
import org.jdmp.core.algorithm.basic.Sensitivity;
import org.jdmp.core.algorithm.basic.Specificity;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.intmatrix.DenseIntMatrix2D;
import org.ujmp.core.util.MathUtil;
import org.ujmp.core.util.concurrent.PFor;

/**
 * Base class for regressors.
 *
 * <p>It stores the labels under which a {@link Sample} holds its input, target and
 * weight matrices, and provides sample-level and data-set-level wrappers
 * ({@link #trainOne(Sample)}, {@link #predictOne(Sample)}, {@link #predictAll(ListDataSet)})
 * around the matrix-level operations.
 *
 * <p>The matrix-level {@code predictOne(Matrix)} used by {@link #predictOne(Sample)} is not
 * defined in this class; it is presumably declared by {@link Regressor} and supplied by
 * concrete subclasses (confirm against the interface).
 */
public abstract class AbstractRegressor
        extends AbstractAlgorithm
        implements Regressor {
    private static final long serialVersionUID = 4674447558395794134L;

    /** Key under which the output-error algorithm is registered in the algorithm map. */
    public static final String OUTPUTERRORALGORITHM = "OutputErrorAlgorithm";

    /** Mode constant with value 0; this is also the initial value of {@link #mode}. */
    public static final int TRAIN = 0;

    /** Mode constant with value 1. */
    public static final int PREDICT = 1;

    /**
     * Current mode, initially 0 ({@link #TRAIN}). Nothing in this class reads it other than
     * {@link #getMode()}.
     *
     * <p>NOTE(review): kept as a public mutable field for backward compatibility; prefer
     * {@link #getMode()} and {@link #setMode(int)}.
     */
    public int mode = 0;

    /** Number of {@link #predictAll(ListDataSet)} runs that produced statistics so far. */
    private int iteration = 0;

    private String inputLabel;
    private String targetLabel;
    private String weightLabel;

    /**
     * Creates a regressor with explicit labels and registers a {@link Minus} algorithm as
     * the output-error algorithm.
     *
     * @param inputLabel  label under which samples store their input matrix
     * @param targetLabel label under which samples store their target matrix
     * @param weightLabel label under which samples store their weight matrix
     */
    public AbstractRegressor(String inputLabel, String targetLabel, String weightLabel) {
        this.inputLabel = inputLabel;
        this.targetLabel = targetLabel;
        this.weightLabel = weightLabel;
        this.setAlgorithm(OUTPUTERRORALGORITHM, new Minus(new Variable[0]));
    }

    /**
     * Creates a regressor using {@code "Weight"} as the weight label.
     *
     * @param inputLabel  label under which samples store their input matrix
     * @param targetLabel label under which samples store their target matrix
     */
    public AbstractRegressor(String inputLabel, String targetLabel) {
        this(inputLabel, targetLabel, "Weight");
    }

    /**
     * Creates a regressor using {@code "Target"} as the target label and {@code "Weight"}
     * as the weight label.
     *
     * @param inputLabel label under which samples store their input matrix
     */
    public AbstractRegressor(String inputLabel) {
        this(inputLabel, "Target");
    }

    /**
     * Creates a regressor using the labels {@code "Input"}, {@code "Target"} and
     * {@code "Weight"}.
     */
    public AbstractRegressor() {
        this("Input");
    }

    /**
     * Returns the row count of the first sample's target matrix after conversion with
     * {@code toRowVector}.
     *
     * <p>Only the sample at index 0 is inspected; an empty data set makes
     * {@code dataSet.get(0)} fail.
     *
     * @param dataSet data set whose first sample is inspected
     * @return the row count, narrowed to {@code int}
     */
    @Override
    public int getClassCount(ListDataSet dataSet) {
        return (int)((Sample)dataSet.get(0)).getAsMatrix(this.getTargetLabel()).toRowVector(Calculation.Ret.NEW).getRowCount();
    }

    /**
     * Checks whether every available input value of every sample is discrete according to
     * {@link MathUtil#isDiscrete}.
     *
     * @param dataSet data set to inspect
     * @return {@code false} as soon as one non-discrete input value is found, otherwise
     *         {@code true} (including for a data set without samples)
     */
    @Override
    public boolean isDiscrete(ListDataSet dataSet) {
        for (Sample s : dataSet) {
            Matrix input = s.getAsMatrix(this.getInputLabel());
            for (long[] c : input.availableCoordinates()) {
                if (!MathUtil.isDiscrete(input.getAsDouble(c))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the row count of the first sample's input matrix after conversion with
     * {@code toRowVector}.
     *
     * <p>Only the first sample returned by the data set's iterator is inspected.
     *
     * @param dataSet data set whose first sample is inspected
     * @return the row count, narrowed to {@code int}
     */
    @Override
    public int getFeatureCount(ListDataSet dataSet) {
        return (int)((Sample)dataSet.iterator().next()).getAsMatrix(this.getInputLabel()).toRowVector(Calculation.Ret.NEW).getRowCount();
    }

    /** @return the label under which samples store their input matrix */
    @Override
    public String getInputLabel() {
        return this.inputLabel;
    }

    /** @return the label under which samples store their weight matrix */
    public String getWeightLabel() {
        return this.weightLabel;
    }

    /** @param weightLabel new label under which samples store their weight matrix */
    public void setWeightLabel(String weightLabel) {
        this.weightLabel = weightLabel;
    }

    /** @param inputLabel new label under which samples store their input matrix */
    public void setInputLabel(String inputLabel) {
        this.inputLabel = inputLabel;
    }

    /** @return the label under which samples store their target matrix */
    @Override
    public String getTargetLabel() {
        return this.targetLabel;
    }

    /** @param targetLabel new label under which samples store their target matrix */
    public void setTargetLabel(String targetLabel) {
        this.targetLabel = targetLabel;
    }

    /**
     * Matrix-level training step. This base implementation always throws; subclasses that
     * support training override it.
     *
     * @param input        input matrix
     * @param sampleWeight weight matrix of the sample
     * @param targetOutput target matrix
     * @throws UnsupportedOperationException always, in this base implementation
     */
    @Override
    public void trainOne(Matrix input, Matrix sampleWeight, Matrix targetOutput) {
        // UnsupportedOperationException is a RuntimeException, so existing handlers still match.
        throw new UnsupportedOperationException("not supported");
    }

    /**
     * Predicts a single sample and stores the results on it.
     *
     * <p>The prediction is stored under {@code "Predicted"} as a column-vector link. If the
     * sample has a target matrix, {@code target - predicted} is stored under
     * {@code "Difference"} and its RMS under {@code "RMSE"}. The sample's value-changed
     * event is fired in both cases.
     *
     * @param sample sample to predict; modified in place
     */
    @Override
    public final void predictOne(Sample sample) {
        Matrix predicted = this.predictOne(sample.getAsMatrix(this.getInputLabel()));
        predicted = predicted.toColumnVector(Calculation.Ret.LINK);
        sample.put("Predicted", predicted);
        if (sample.getAsMatrix(this.getTargetLabel()) != null) {
            Matrix target = sample.getAsMatrix(this.getTargetLabel()).toColumnVector(Calculation.Ret.LINK);
            Matrix error = target.minus(predicted);
            sample.put("Difference", error);
            sample.put("RMSE", Matrix.Factory.linkToValue(error.getRMS()));
        }
        sample.fireValueChanged();
    }

    /**
     * Trains on one input/target pair using a sample weight of {@code 1.0}.
     *
     * @param input        input matrix
     * @param targetOutput target matrix
     */
    @Override
    public final void trainOne(Matrix input, Matrix targetOutput) {
        this.trainOne(input, Matrix.Factory.linkToValue(1.0), targetOutput);
    }

    /**
     * Trains on one sample by reading its input, weight and target matrices via the
     * configured labels and delegating to {@link #trainOne(Matrix, Matrix, Matrix)}.
     *
     * <p>The matrices are passed on as returned by the sample; no default is substituted
     * here for an absent weight.
     *
     * @param sample sample to train on
     */
    @Override
    public final void trainOne(Sample sample) {
        Matrix input = sample.getAsMatrix(this.getInputLabel());
        Matrix weight = sample.getAsMatrix(this.getWeightLabel());
        Matrix target = sample.getAsMatrix(this.getTargetLabel());
        this.trainOne(input, weight, target);
    }

    /**
     * Predicts every sample of the data set and, if the first sample has a target matrix,
     * stores evaluation statistics on the data set.
     *
     * <p>Statistics are stored under the keys {@code "RMSE"}, {@code "Confusion"},
     * {@code "Accuracy"}, {@code "ErrorCount"}, {@code "Sensitivity"}, {@code "Specificity"},
     * {@code "Precision"}, {@code "Recall"}, {@code "FMeasure"} and {@code "FMeasureMacro"}.
     * A summary line is printed to standard output. The data set's value-changed event is
     * fired at the end in every case.
     *
     * <p>An empty data set yields no predictions and no statistics; only the event is fired.
     *
     * @param dataSet data set to predict; its samples and matrices are modified in place
     */
    @Override
    public void predictAll(final ListDataSet dataSet) {
        // Nothing to predict or evaluate; also avoids dataSet.get(0) failing below.
        if (dataSet.size() == 0) {
            dataSet.fireValueChanged();
            return;
        }

        // Constructing the PFor runs step(i) for the index range 0 .. size-1.
        new PFor(0, dataSet.size() - 1){

            @Override
            public void step(int i) {
                Sample sample = (Sample)dataSet.get(i);
                AbstractRegressor.this.predictOne(sample);
            }
        };

        // Statistics are only computed when the first sample carries a target.
        if (((Sample)dataSet.get(0)).getAsMatrix(this.getTargetLabel()) != null) {
            double squaredErrorSum = 0.0;
            int correctCount = 0;
            int errorCount = 0;
            int classCount = this.getClassCount(dataSet);

            // Dimension 0 is indexed by the target class, dimension 1 by the recognized class.
            DenseMatrix confusion = Matrix.Factory.zeros((long)classCount, (long)classCount);
            confusion.setDimensionLabel(0, "expected");
            confusion.setDimensionLabel(1, "predicted");

            for (Sample sample : dataSet) {
                // NOTE(review): a later sample without a target has no "RMSE" entry and would
                // cause a NullPointerException here; only the first sample is checked above.
                double sampleRmse = sample.getAsMatrix("RMSE").getEuklideanValue();
                squaredErrorSum += Math.pow(sampleRmse, 2.0);

                int recognized = sample.getRecognizedClass();
                int targetClass = sample.getTargetClass();
                if (classCount == 1 || recognized == -1) {
                    // Single-class data and unrecognized samples are all counted in cell (0, 0).
                    confusion.setAsDouble(confusion.getAsDouble(0L, 0L) + 1.0, 0L, 0L);
                } else {
                    // NOTE(review): targetClass is not checked for -1 before being used as index.
                    confusion.setAsDouble(confusion.getAsDouble(targetClass, recognized) + 1.0, targetClass, recognized);
                }

                if (sample.isCorrect()) {
                    ++correctCount;
                } else {
                    ++errorCount;
                }
            }

            // Data-set RMSE: square root of the mean of the squared per-sample RMSE values.
            DenseDoubleMatrix2D rmseMatrix = Matrix.Factory.linkToValue(Math.sqrt(squaredErrorSum / (double)dataSet.size()));
            rmseMatrix.setLabel("RMSE with " + this.getLabel());
            dataSet.setMatrix("RMSE", rmseMatrix);

            confusion.setLabel("Confusion with " + this.getLabel());
            dataSet.setMatrix("Confusion", confusion);

            DenseDoubleMatrix2D accuracy = Matrix.Factory.linkToValue((double)correctCount / (double)dataSet.size());
            accuracy.setLabel("Accuracy with " + this.getLabel());
            dataSet.setMatrix("Accuracy", accuracy);

            DenseIntMatrix2D errorMatrix = Matrix.Factory.linkToValue(errorCount);
            errorMatrix.setLabel("Errors with " + this.getLabel());
            dataSet.setMatrix("ErrorCount", errorMatrix);

            // Each metric algorithm is applied to the confusion matrix; its result is read
            // from the returned map under the target label.
            Matrix sensitivity = (Matrix)new Sensitivity(new Variable[0]).calculate(confusion).get(this.getTargetLabel());
            dataSet.setMatrix("Sensitivity", sensitivity);
            Matrix specificity = (Matrix)new Specificity(new Variable[0]).calculate(confusion).get(this.getTargetLabel());
            dataSet.setMatrix("Specificity", specificity);
            Matrix precision = (Matrix)new Precision(new Variable[0]).calculate(confusion).get(this.getTargetLabel());
            dataSet.setMatrix("Precision", precision);
            Matrix recall = (Matrix)new Recall(new Variable[0]).calculate(confusion).get(this.getTargetLabel());
            dataSet.setMatrix("Recall", recall);
            Matrix fmeasure = (Matrix)new FMeasure(new Variable[0]).calculate(confusion).get(this.getTargetLabel());
            dataSet.setMatrix("FMeasure", fmeasure);
            // Presumably the mean over all entries of the F-measure matrix - TODO confirm the
            // meaning of Integer.MAX_VALUE as the dimension argument.
            Matrix fmeasureMacro = fmeasure.mean(Calculation.Ret.NEW, Integer.MAX_VALUE, false);
            dataSet.setMatrix("FMeasureMacro", fmeasureMacro);

            ++this.iteration;
            System.out.println("Iteration: " + this.iteration + ", RMSE: " + rmseMatrix.doubleValue() + ", errors: " + errorCount + ", accuracy: " + accuracy.doubleValue());
        }
        dataSet.fireValueChanged();
    }

    /** @return the algorithm registered under {@link #OUTPUTERRORALGORITHM} */
    public Algorithm getOutputErrorAlgorithm() {
        return (Algorithm)this.getAlgorithmMap().get(OUTPUTERRORALGORITHM);
    }

    /** @param a algorithm to register under {@link #OUTPUTERRORALGORITHM} */
    public void setOutputErrorAlgorithm(Algorithm a) {
        this.setAlgorithm(OUTPUTERRORALGORITHM, a);
    }

    /** @return the current value of {@link #mode} */
    public int getMode() {
        return this.mode;
    }

    /** @param mode new value for {@link #mode}, e.g. {@link #TRAIN} or {@link #PREDICT} */
    public void setMode(int mode) {
        this.mode = mode;
    }
}

