/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.svm.training;

import org.encog.EncogError;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.EncodeSVMProblem;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.logging.EncogLogging;

/**
 * Single-pass trainer for an {@link SVM} backed by libsvm.
 *
 * <p>Each call to {@link #iteration()} copies the current {@code C} and {@code gamma} into the
 * SVM's parameters. It then does one of two things:
 * <ul>
 *   <li>When the fold count is greater than one, it scores that parameter pair with k-fold
 *       cross-validation and leaves the SVM without a model.</li>
 *   <li>Otherwise, it trains a model on the whole encoded training set.</li>
 * </ul>
 *
 * <p>Training cannot be paused or resumed: {@link #canContinue()} returns {@code false},
 * {@link #pause()} returns {@code null}, and {@link #resume(TrainingContinuation)} does nothing.
 */
public class SVMTrain
extends BasicTraining {
    /**
     * Smallest value accepted for {@code C} and {@code gamma}. Smaller values, including zero,
     * negatives and NaN, are rejected by the setters.
     */
    private static final double MIN_PARAMETER_VALUE = 1.0E-13;

    /** The SVM whose parameters and model this trainer sets. */
    private final SVM network;
    /** The training set, encoded once in the constructor into libsvm's problem form. */
    private final svm_problem problem;
    /** Number of cross-validation folds; one or less means "train on the full set". */
    private int fold = 0;
    /** Set to {@code true} once {@link #iteration()} has run. */
    private boolean trainingDone;
    /** Kernel gamma passed to libsvm on the next iteration. */
    private double gamma;
    /** Cost parameter C passed to libsvm on the next iteration. */
    private double c;

    /**
     * Creates a trainer for {@code method} using {@code dataSet}.
     *
     * <p>The data set is encoded into a libsvm problem immediately. {@code gamma} defaults to
     * {@code 1 / method.getInputCount()} and {@code C} defaults to {@code 1.0}.
     *
     * @param method  the SVM to train
     * @param dataSet the training data
     */
    public SVMTrain(SVM method, MLDataSet dataSet) {
        super(TrainingImplementationType.OnePass);
        this.network = method;
        this.setTraining(dataSet);
        this.trainingDone = false;
        // NOTE(review): the second argument presumably selects which ideal column is encoded
        // as the libsvm label — confirm against EncodeSVMProblem.
        this.problem = EncodeSVMProblem.encode(dataSet, 0);
        this.gamma = 1.0 / (double) this.network.getInputCount();
        this.c = 1.0;
    }

    /**
     * Reports whether this training can be continued later.
     *
     * @return always {@code false}; this trainer cannot be continued
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Scores cross-validation predictions against the encoded labels.
     *
     * <p>For the regression types ({@code EPSILON_SVR}, {@code NU_SVR}) it returns the error
     * computed by {@link ErrorCalculation} over all samples. For every other SVM type it returns
     * the percentage (0-100) of samples whose prediction exactly equals the label.
     *
     * <p>NOTE(review): the classification result is an accuracy (higher is better), yet
     * {@link #iteration()} stores it via {@code setError}. Confirm that callers comparing errors
     * account for this.
     *
     * @param param  the parameters used for cross-validation (only {@code svm_type} is read)
     * @param prob   the problem holding {@code l} samples and their labels {@code y}
     * @param target the predicted value for each of the {@code l} samples
     * @return the regression error, or the classification accuracy percentage
     */
    private double evaluate(svm_parameter param, svm_problem prob, double[] target) {
        boolean regression = param.svm_type == svm_parameter.EPSILON_SVR
                || param.svm_type == svm_parameter.NU_SVR;
        if (regression) {
            ErrorCalculation error = new ErrorCalculation();
            for (int i = 0; i < prob.l; i++) {
                // updateError(actual, ideal): prediction first, then the true label.
                error.updateError(target[i], prob.y[i]);
            }
            return error.calculate();
        }
        int totalCorrect = 0;
        for (int i = 0; i < prob.l; i++) {
            // Exact comparison is intended: classification labels are discrete class values.
            if (target[i] == prob.y[i]) {
                totalCorrect++;
            }
        }
        return 100.0 * (double) totalCorrect / (double) prob.l;
    }

    /**
     * Returns the cost parameter C used on the next iteration.
     *
     * @return the current C value
     */
    public double getC() {
        return this.c;
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return the fold count; one or less means no cross-validation
     */
    public int getFold() {
        return this.fold;
    }

    /**
     * Returns the kernel gamma used on the next iteration.
     *
     * @return the current gamma value
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * Returns the SVM being trained.
     *
     * @return the SVM passed to the constructor
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * Returns the libsvm problem built from the training set in the constructor.
     *
     * @return the encoded training problem
     */
    public svm_problem getProblem() {
        return this.problem;
    }

    /**
     * Reports whether training has finished.
     *
     * @return {@code true} once {@link #iteration()} has been called
     */
    @Override
    public boolean isTrainingDone() {
        return this.trainingDone;
    }

    /**
     * Runs the single training pass.
     *
     * <p>The current C and gamma are written into the SVM's parameters first. Then:
     * <ul>
     *   <li>If the fold count is greater than one, libsvm cross-validation is run, the SVM's
     *       model is cleared to {@code null}, and the error is set from {@link #evaluate}.</li>
     *   <li>Otherwise, a model is trained on the full problem, installed in the SVM, and the
     *       error is set from the SVM's own error calculation on the training set.</li>
     * </ul>
     * In both cases training is marked as done.
     */
    @Override
    public void iteration() {
        svm_parameter params = this.network.getParams();
        params.C = this.c;
        params.gamma = this.gamma;
        EncogLogging.log(1, "Training with parameters C = " + this.c + ", gamma = " + this.gamma);
        if (this.fold > 1) {
            double[] target = new double[this.problem.l];
            svm.svm_cross_validation(this.problem, params, this.fold, target);
            // Cross-validation only scores the parameters; it does not leave a usable model.
            this.network.setModel(null);
            this.setError(this.evaluate(params, this.problem, target));
        } else {
            this.network.setModel(svm.svm_train(this.problem, params));
            this.setError(this.network.calculateError(this.getTraining()));
        }
        this.trainingDone = true;
    }

    /**
     * Pausing is not supported.
     *
     * @return always {@code null}
     */
    @Override
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported; this method intentionally does nothing.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
        // Intentionally empty: SVM training is a single pass and keeps no resumable state.
    }

    /**
     * Sets the cost parameter C used on the next iteration.
     *
     * @param theC the new C value; must be at least {@code 1.0E-13}
     * @throws EncogError if {@code theC} is below {@code 1.0E-13} or is NaN; the current value
     *                    is left unchanged
     */
    public void setC(double theC) {
        // Validate before assigning so a rejected value never replaces the current one.
        // The negated ">=" also rejects NaN, which a plain "<" comparison would let through.
        if (!(theC >= MIN_PARAMETER_VALUE)) {
            throw new EncogError("SVM training requires a c value of at least "
                    + MIN_PARAMETER_VALUE + ", got " + theC + ".");
        }
        this.c = theC;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param theFold the fold count; values of one or less disable cross-validation
     */
    public void setFold(int theFold) {
        this.fold = theFold;
    }

    /**
     * Sets the kernel gamma used on the next iteration.
     *
     * @param theGamma the new gamma value; must be at least {@code 1.0E-13}
     * @throws EncogError if {@code theGamma} is below {@code 1.0E-13} or is NaN; the current
     *                    value is left unchanged
     */
    public void setGamma(double theGamma) {
        // Validate before assigning so a rejected value never replaces the current one.
        // The negated ">=" also rejects NaN, which a plain "<" comparison would let through.
        if (!(theGamma >= MIN_PARAMETER_VALUE)) {
            throw new EncogError("SVM training requires a gamma value of at least "
                    + MIN_PARAMETER_VALUE + ", got " + theGamma + ".");
        }
        this.gamma = theGamma;
    }
}

