package org.ea.javacnn.losslayers;

import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;

/* loaded from: input_file:org/ea/javacnn/losslayers/RegressionLayer.class */
/**
 * Loss layer for regression: passes its input through unchanged on the forward pass and, on the
 * backward pass, computes a squared-error (L2) loss against real-valued targets.
 *
 * <p>For each output index {@code i}, the gradient written into the input block is
 * {@code prediction[i] - target[i]}. The loss contribution is
 * {@code 0.5 * (prediction[i] - target[i])^2}, where {@code prediction[i]} is the value read from
 * the input block via {@code getWeight(i)}.
 */
public class RegressionLayer extends LossLayer {

    /**
     * Creates a regression loss layer.
     *
     * @param outputDefinition layer shape definition, forwarded unchanged to {@link LossLayer}
     */
    public RegressionLayer(OutputDefinition outputDefinition) {
        super(outputDefinition);
    }

    /**
     * Identity forward pass. Remembers {@code dataBlock} as both the input and output activation
     * so that a later {@code backward} call can write gradients into it.
     *
     * @param dataBlock  activations produced by the previous layer
     * @param isTraining not read by this layer
     * @return {@code dataBlock}, unchanged
     */
    @Override
    public DataBlock forward(DataBlock dataBlock, boolean isTraining) {
        this.in_act = dataBlock;
        this.out_act = dataBlock;
        return dataBlock;
    }

    /**
     * Computes the squared-error loss of the last forward input against a vector of targets and
     * stores the per-element gradient back into that input block.
     *
     * <p>Only the first {@code out_depth} entries of {@code targets} are used; extra entries are
     * ignored.
     *
     * @param targets expected values, at least {@code out_depth} long
     * @return the summed loss {@code sum_i 0.5 * (prediction[i] - targets[i])^2}
     * @throws IllegalStateException    if {@code forward} has not supplied an input block yet
     * @throws IllegalArgumentException if {@code targets} is null or shorter than {@code out_depth}
     */
    public double backward(double[] targets) {
        DataBlock input = requireInput();
        // Validate before clearing gradients so a bad call leaves the block untouched
        // instead of half-overwritten.
        if (targets == null) {
            throw new IllegalArgumentException("targets must not be null");
        }
        if (targets.length < this.out_depth) {
            throw new IllegalArgumentException(
                    "targets length " + targets.length + " is smaller than out_depth " + this.out_depth);
        }

        input.clearGradient();
        double loss = 0.0d;
        for (int i = 0; i < this.out_depth; i++) {
            // d/dx of 0.5 * (x - t)^2 is (x - t), so the residual is also the gradient.
            double residual = input.getWeight(i) - targets[i];
            input.setGradient(i, residual);
            loss += 0.5d * residual * residual;
        }
        return loss;
    }

    /**
     * Computes the squared-error loss of the first element of the last forward input against a
     * single integer target and stores its gradient at index 0.
     *
     * <p>Only index 0 is read or written, so this overload presumably targets single-output
     * regression. NOTE(review): confirm with callers.
     *
     * @param target expected value for output index 0
     * @return {@code 0.5 * (prediction[0] - target)^2}
     * @throws IllegalStateException if {@code forward} has not supplied an input block yet
     */
    @Override
    public double backward(int target) {
        DataBlock input = requireInput();
        input.clearGradient();
        double residual = input.getWeight(0) - target;
        input.setGradient(0, residual);
        return 0.5d * residual * residual;
    }

    /** Returns the block recorded by {@code forward}, failing clearly if there is none. */
    private DataBlock requireInput() {
        if (this.in_act == null) {
            throw new IllegalStateException("forward must be called before backward");
        }
        return this.in_act;
    }
}
