/*
 * Decompiled with CFR 0.152.
 */
package jsat.lossfunctions;

import jsat.classifiers.CategoricalResults;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.lossfunctions.LossMC;
import jsat.math.MathTricks;

/**
 * Multi-class softmax (cross-entropy) loss.
 *
 * <p>Raw per-class scores are turned into a probability vector by {@link #process(Vec, Vec)}.
 * The remaining methods ({@link #getLoss(Vec, int)}, {@link #deriv(Vec, Vec, int)},
 * {@link #getClassification(Vec)}) expect that processed vector as their input.
 * Binary-loss behavior is inherited unchanged from {@link LogisticLoss}.
 */
public class SoftmaxLoss extends LogisticLoss implements LossMC {
    private static final long serialVersionUID = 3936898932252996024L;

    /**
     * Computes the negative log-likelihood of the true class.
     *
     * @param processed the vector produced by {@link #process(Vec, Vec)}, assumed to hold
     *                  class probabilities
     * @param y         index of the true class
     * @return {@code -ln(processed[y])}; this is {@code +Infinity} if that entry is exactly 0
     */
    @Override
    public double getLoss(Vec processed, int y) {
        final double trueClassProb = processed.get(y);
        return -Math.log(trueClassProb);
    }

    /**
     * Writes the softmax of the raw scores into {@code processed}.
     *
     * <p>When the two arguments are different objects, the scores are first copied into
     * {@code processed}. When they are the same object, the transform happens in place.
     *
     * @param pred      raw per-class scores
     * @param processed destination vector, overwritten with the softmax output
     */
    @Override
    public void process(Vec pred, Vec processed) {
        final boolean needsCopy = processed != pred;
        if (needsCopy) {
            pred.copyTo(processed);
        }
        // NOTE(review): meaning of the boolean flag is defined by MathTricks.softmax; confirm there.
        MathTricks.softmax(processed, false);
    }

    /**
     * Fills {@code derivs} with {@code p_i - 1} for the true class and {@code p_i} otherwise.
     *
     * <p>When {@code processed} came from {@link #process(Vec, Vec)}, this is the gradient of
     * the cross-entropy with respect to the pre-softmax scores.
     *
     * <p>Only indices {@code 0 .. processed.length()-1} are written. Reading each entry before
     * writing it keeps the result correct even if {@code derivs} is the same object as
     * {@code processed}.
     *
     * @param processed class probabilities from {@link #process(Vec, Vec)}
     * @param derivs    output vector receiving the derivative values
     * @param y         index of the true class
     */
    @Override
    public void deriv(Vec processed, Vec derivs, int y) {
        final int classCount = processed.length();
        for (int c = 0; c < classCount; c++) {
            final double prob = processed.get(c);
            derivs.set(c, c == y ? prob - 1.0 : prob);
        }
    }

    /**
     * Wraps the processed probabilities in a {@link CategoricalResults}.
     *
     * @param processed class probabilities from {@link #process(Vec, Vec)}
     * @return a result built from a fresh array copy, so later changes to {@code processed} do
     *         not affect it
     */
    @Override
    public CategoricalResults getClassification(Vec processed) {
        final double[] probabilities = processed.arrayCopy();
        return new CategoricalResults(probabilities);
    }
}

