/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Classification;

import Catalano.MachineLearning.Classification.IClassifier;
import Catalano.MachineLearning.Dataset.DatasetClassification;
import Catalano.Math.Distances.IDivergence;
import Catalano.Math.Distances.SquaredEuclideanDistance;
import Catalano.Math.Matrix;

/**
 * Minimum mean distance (nearest class mean) classifier.
 *
 * <p>{@link #Learn(double[][], int[])} computes, for every class label, the component-wise mean
 * of the training samples carrying that label. {@link #Predict(double[])} returns the label whose
 * mean has the smallest divergence to the query sample, using the configured {@link IDivergence}
 * ({@link SquaredEuclideanDistance} by default).
 *
 * <p>Labels must be non-negative integers; the number of classes is the largest label plus one.
 * Labels in that range that never occur in the training data have no mean and are never
 * predicted.
 */
public class MinimumMeanDistance
implements IClassifier, Cloneable {
    /** Divergence used to compare a query sample with each class mean. */
    private final IDivergence divergence;
    /** Per-class mean vectors indexed by label; {@code null} until the classifier is trained. */
    private double[][] means;
    /** Number of training samples per label; a zero entry marks a label with no usable mean. */
    private int[] counts;

    /** Creates a classifier that compares samples with {@link SquaredEuclideanDistance}. */
    public MinimumMeanDistance() {
        this(new SquaredEuclideanDistance());
    }

    /**
     * Creates a classifier that compares samples with the given divergence.
     *
     * @param divergence divergence between a class mean and a query sample
     * @throws NullPointerException if {@code divergence} is null
     */
    public MinimumMeanDistance(IDivergence divergence) {
        if (divergence == null) {
            throw new NullPointerException("divergence");
        }
        this.divergence = divergence;
    }

    /**
     * Trains the classifier on the inputs and labels held by {@code dataset}.
     *
     * @param dataset training data
     */
    @Override
    public void Learn(DatasetClassification dataset) {
        this.Learn(dataset.getInput(), dataset.getOutput());
    }

    /**
     * Computes the mean vector of every class from the training samples.
     *
     * <p>The previously learned model (if any) is replaced only after all arguments have been
     * validated, so a rejected call leaves the classifier unchanged.
     *
     * @param input  training samples, one row per sample; all rows must have the same length
     * @param output label of each sample; non-negative, same length as {@code input}
     * @throws IllegalArgumentException if {@code input} is empty, the lengths of {@code input} and
     *     {@code output} differ, a label is negative, or the rows have different lengths
     */
    @Override
    public void Learn(double[][] input, int[] output) {
        if (input.length == 0) {
            throw new IllegalArgumentException("Training set must contain at least one sample.");
        }
        if (input.length != output.length) {
            throw new IllegalArgumentException("Number of samples (" + input.length
                    + ") does not match number of labels (" + output.length + ").");
        }
        // Labels index the per-class arrays directly, so negative values must be rejected
        // before the class count is derived from the maximum label.
        for (int i = 0; i < output.length; i++) {
            if (output[i] < 0) {
                throw new IllegalArgumentException(
                        "Labels must be non-negative, but output[" + i + "] = " + output[i] + ".");
            }
        }

        int dimension = input[0].length;
        int classes = Matrix.Max(output) + 1;
        int[] sampleCounts = this.CountGroups(output, classes);

        // Accumulate per-class sums.
        double[][] sums = new double[classes][dimension];
        for (int i = 0; i < input.length; i++) {
            double[] sample = input[i];
            if (sample.length != dimension) {
                throw new IllegalArgumentException("Sample " + i + " has " + sample.length
                        + " features, expected " + dimension + ".");
            }
            double[] sum = sums[output[i]];
            for (int j = 0; j < dimension; j++) {
                sum[j] += sample[j];
            }
        }

        // Turn sums into means. Classes without samples are skipped: dividing by zero would
        // yield NaN means, and a NaN distance breaks the arg-min in Predict.
        for (int c = 0; c < classes; c++) {
            if (sampleCounts[c] == 0) {
                continue;
            }
            double[] mean = sums[c];
            for (int j = 0; j < dimension; j++) {
                mean[j] /= sampleCounts[c];
            }
        }

        this.means = sums;
        this.counts = sampleCounts;
    }

    /**
     * Predicts the label whose class mean has the smallest divergence to {@code feature}.
     *
     * <p>Labels with no training samples are ignored. On ties, the smallest label wins.
     *
     * @param feature sample to classify
     * @return predicted class label
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override
    public int Predict(double[] feature) {
        if (this.means == null) {
            throw new IllegalStateException(
                    "Classifier must be trained with Learn before calling Predict.");
        }
        int best = -1;
        double bestDistance = 0.0;
        for (int c = 0; c < this.means.length; c++) {
            if (this.counts[c] == 0) {
                continue; // no mean was learned for this label
            }
            double d = this.divergence.Compute(this.means[c], feature);
            // The first usable class seeds the minimum; afterwards only a strictly smaller
            // distance replaces it, so ties resolve to the smallest label.
            if (best < 0 || d < bestDistance) {
                best = c;
                bestDistance = d;
            }
        }
        // Learn guarantees at least one class (the maximum label) has samples, so best >= 0.
        return best;
    }

    /**
     * Returns an independent copy of this classifier.
     *
     * <p>The divergence instance is shared. The learned means and counts are copied, so
     * training or inspecting the copy does not affect this instance.
     *
     * @return a copy of this classifier
     */
    @Override
    public IClassifier clone() {
        try {
            MinimumMeanDistance copy = (MinimumMeanDistance) super.clone();
            if (this.means != null) {
                copy.means = new double[this.means.length][];
                for (int c = 0; c < this.means.length; c++) {
                    copy.means[c] = this.means[c].clone();
                }
                copy.counts = this.counts.clone();
            }
            return copy;
        }
        catch (CloneNotSupportedException ex) {
            // Unreachable while this class implements Cloneable; the cause is kept for diagnosis.
            throw new IllegalArgumentException("Clone not supported: " + ex.getMessage(), ex);
        }
    }

    /**
     * Counts how many samples carry each label.
     *
     * @param labels  sample labels, each in {@code [0, classes)}
     * @param classes number of classes
     * @return array of length {@code classes} where entry {@code c} is the count of label {@code c}
     */
    private int[] CountGroups(int[] labels, int classes) {
        int[] groups = new int[classes];
        for (int label : labels) {
            groups[label]++;
        }
        return groups;
    }
}

