/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.meta;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.DefaultListDataSet;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.DefaultSample;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;
import org.ujmp.core.util.MathUtil;

/**
 * Meta-regressor that reduces every input to a subset of its features and delegates training and
 * prediction on the reduced input to a wrapped {@link Regressor}.
 *
 * <p>The feature subset is (re)computed on every {@link #trainAll(ListDataSet)} call, either at
 * random or greedily from the absolute covariance / mutual-information values between each input
 * feature and each target column.
 */
public class FeatureSelector
extends AbstractRegressor {
    private static final long serialVersionUID = -4290061842734681590L;
    /** Regressor that is trained on, and queried with, the reduced inputs. */
    private final Regressor learningAlgorithm;
    /** Requested number of input features to keep. */
    private final int selectedFeatureCount;
    /** Indices of the kept input features; filled by selectFeatures, cleared by reset. */
    private final List<Integer> selectedFeatures = new FastArrayList<Integer>();
    /** Strategy used to choose the kept features. */
    private final SelectionType selectionType;

    /**
     * @param selectionType strategy used to pick features
     * @param learningAlgorithm regressor trained on the reduced inputs
     * @param featureCount number of input features to keep
     */
    public FeatureSelector(SelectionType selectionType, Regressor learningAlgorithm, int featureCount) {
        this.selectionType = selectionType;
        this.learningAlgorithm = learningAlgorithm;
        this.selectedFeatureCount = featureCount;
    }

    /**
     * Chooses the input features to keep and stores their indices in {@link #selectedFeatures},
     * replacing any previous selection.
     *
     * <p>{@code Random} keeps a random subset of at most {@code selectedFeatureCount} indices.
     * Every other strategy scores each (feature, target) pair by the absolute value of the
     * covariance or mutual-information matrix of the combined [inputs | targets] matrix and,
     * cycling over the targets, repeatedly adds the best-scoring feature not yet selected. At most
     * as many features as exist are selected. If in a full pass over the targets no remaining
     * feature scores above zero, the remaining slots are filled with the lowest unselected indices,
     * which guarantees termination.
     */
    private void selectFeatures(ListDataSet dataSet) {
        // Start from a clean slate so that training again does not accumulate duplicate indices.
        this.selectedFeatures.clear();
        if (this.selectionType == SelectionType.Random) {
            this.selectedFeatures.addAll(MathUtil.sequenceListInt(0, this.getFeatureCount(dataSet)));
            while (this.selectedFeatures.size() > this.selectedFeatureCount) {
                this.selectedFeatures.remove(MathUtil.nextInteger(this.selectedFeatures.size()));
            }
            return;
        }
        int featureCount = this.getFeatureCount(dataSet);
        int targetCount = this.getClassCount(dataSet);
        // Asking for more features than exist would make the greedy loop below impossible to finish.
        int wanted = Math.min(this.selectedFeatureCount, featureCount);
        Matrix m = this.createCompleteMatrix(dataSet);
        Matrix comp = this.selectionType == SelectionType.MutualInformation
                ? m.mutualInf(Calculation.Ret.NEW)
                : m.cov(Calculation.Ret.NEW, true, true);
        Set<Integer> featureSet = new HashSet<Integer>();
        while (featureSet.size() < wanted) {
            boolean progress = false;
            for (int t = 0; t < targetCount && featureSet.size() < wanted; ++t) {
                int best = findBestFeature(comp, featureCount, t, featureSet);
                if (best >= 0) {
                    featureSet.add(best);
                    progress = true;
                }
            }
            if (!progress) {
                // No unselected feature scores above zero (e.g. all zero/NaN, or no targets):
                // fill the remaining slots deterministically instead of spinning forever.
                for (int f = 0; f < featureCount && featureSet.size() < wanted; ++f) {
                    featureSet.add(f);
                }
            }
        }
        this.selectedFeatures.addAll(featureSet);
        Collections.sort(this.selectedFeatures);
    }

    /**
     * Returns the index of the feature with the largest strictly positive absolute score for the
     * given target, ignoring features in {@code excluded}; ties keep the lowest index.
     *
     * @param comp score matrix; row {@code f} / column {@code featureCount + target} holds the
     *     score between feature {@code f} and that target
     * @return the best feature index, or -1 if no candidate scores above zero
     */
    private static int findBestFeature(Matrix comp, int featureCount, int target, Set<Integer> excluded) {
        double bestValue = 0.0;
        int bestFeature = -1;
        for (int f = 0; f < featureCount; ++f) {
            if (excluded.contains(f)) {
                continue;
            }
            double val = Math.abs(comp.getAsDouble(new long[]{f, featureCount + target}));
            // NaN compares false here, so undefined scores are never chosen.
            if (val > bestValue) {
                bestValue = val;
                bestFeature = f;
            }
        }
        return bestFeature;
    }

    /**
     * Builds a {@code sampleCount x (featureCount + targetCount)} dense matrix whose first
     * {@code featureCount} columns hold each sample's input values and whose remaining columns hold
     * its target values.
     */
    private Matrix createCompleteMatrix(ListDataSet dataSet) {
        int sampleCount = dataSet.size();
        int featureCount = this.getFeatureCount(dataSet);
        int targetCount = this.getClassCount(dataSet);
        DenseMatrix m = Matrix.Factory.zeros((long)sampleCount, (long)(featureCount + targetCount));
        for (int r = 0; r < sampleCount; ++r) {
            Sample s = (Sample)dataSet.get(r);
            // NOTE(review): values are read at (0, c) of the toColumnVector result, i.e. it is
            // treated as a single row here, as elsewhere in this class — confirm UJMP orientation.
            Matrix input = s.getAsMatrix(this.getInputLabel()).toColumnVector(Calculation.Ret.NEW);
            Matrix target = s.getAsMatrix(this.getTargetLabel()).toColumnVector(Calculation.Ret.NEW);
            for (int c = 0; c < featureCount; ++c) {
                m.setAsDouble(input.getAsDouble(0L, c), r, c);
            }
            for (int c = 0; c < targetCount; ++c) {
                m.setAsDouble(target.getAsDouble(0L, c), r, c + featureCount);
            }
        }
        return m;
    }

    /**
     * Copies the currently selected features of {@code input} into a new 1 x k matrix, where k is
     * the number of selected features, in the order of {@link #selectedFeatures}.
     */
    private Matrix reduceInput(Matrix input) {
        Matrix vector = input.toColumnVector(Calculation.Ret.NEW);
        DenseMatrix reduced = Matrix.Factory.zeros(1L, (long)this.selectedFeatures.size());
        for (int i = 0; i < this.selectedFeatures.size(); ++i) {
            reduced.setAsDouble(vector.getAsDouble(0L, this.selectedFeatures.get(i).intValue()), 0L, i);
        }
        return reduced;
    }

    /**
     * Selects features on {@code dataSet}, then trains the wrapped regressor on a copy of the data
     * set whose inputs are reduced to those features (targets are passed through unchanged).
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        this.selectFeatures(dataSet);
        DefaultListDataSet newDataSet = new DefaultListDataSet();
        for (Sample s1 : dataSet) {
            DefaultSample s2 = new DefaultSample();
            s2.put(this.getInputLabel(), this.reduceInput(s1.getAsMatrix(this.getInputLabel())));
            s2.put(this.getTargetLabel(), s1.getAsMatrix(this.getTargetLabel()));
            newDataSet.add(s2);
        }
        this.learningAlgorithm.trainAll(newDataSet);
    }

    /** Resets the wrapped regressor and forgets the current feature selection. */
    @Override
    public void reset() {
        this.learningAlgorithm.reset();
        this.selectedFeatures.clear();
    }

    /** Reduces {@code input} to the selected features and returns the wrapped regressor's prediction. */
    @Override
    public Matrix predictOne(Matrix input) {
        return this.learningAlgorithm.predictOne(this.reduceInput(input));
    }

    /**
     * Returns an untrained copy with the same selection strategy, feature count, labels and an
     * empty copy of the wrapped regressor.
     */
    @Override
    public Regressor emptyCopy() {
        FeatureSelector r = new FeatureSelector(this.selectionType, this.learningAlgorithm.emptyCopy(), this.selectedFeatureCount);
        r.setInputLabel(this.getInputLabel());
        r.setTargetLabel(this.getTargetLabel());
        return r;
    }

    /** Strategy used to choose which input features are kept. */
    public static enum SelectionType {
        /** Keep a random subset of features. */
        Random,
        /** Rank features by absolute mutual information with each target column. */
        MutualInformation,
        /** Rank features by absolute covariance with each target column. */
        Covariance;

    }
}

