package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/loss/SigmoidBinaryCrossEntropyLoss.class */
/**
 * Binary cross-entropy loss that accepts either raw scores (logits) or values that have already
 * been passed through a sigmoid.
 *
 * <p>With label {@code y} and prediction {@code x} the element-wise loss is
 *
 * <ul>
 *   <li>{@code relu(x) - x * y + softPlus(-|x|)} when {@code fromSigmoid} is {@code false};
 *   <li>{@code -(log(x + eps) * y + log(1 - x + eps) * (1 - y))} when {@code fromSigmoid} is
 *       {@code true}.
 * </ul>
 *
 * <p>The element-wise loss is multiplied by {@code weight} (when it differs from 1) and then
 * averaged over all elements.
 */
public class SigmoidBinaryCrossEntropyLoss extends Loss {
    /** Added to the argument of every logarithm so that {@code log(0)} is never evaluated. */
    private static final double EPSILON = 1.0E-12d;

    private final float weight;
    private final boolean fromSigmoid;

    /** Creates the loss with the name {@code "SigmoidBinaryCrossEntropyLoss"}, weight 1 and logits as input. */
    public SigmoidBinaryCrossEntropyLoss() {
        this("SigmoidBinaryCrossEntropyLoss");
    }

    /**
     * Creates the loss with weight 1 and logits as input.
     *
     * @param name the name passed to the {@code Loss} super constructor
     */
    public SigmoidBinaryCrossEntropyLoss(String name) {
        this(name, 1.0f, false);
    }

    /**
     * Creates the loss.
     *
     * @param name the name passed to the {@code Loss} super constructor
     * @param weight the factor the element-wise loss is multiplied by before averaging
     * @param fromSigmoid {@code true} if the predictions have already been passed through a
     *     sigmoid, {@code false} if they are raw scores
     */
    public SigmoidBinaryCrossEntropyLoss(String name, float weight, boolean fromSigmoid) {
        super(name);
        this.weight = weight;
        this.fromSigmoid = fromSigmoid;
    }

    /**
     * Computes the mean (optionally weighted) binary cross-entropy.
     *
     * @param labels a list holding exactly one label array; it is reshaped to the shape of the
     *     prediction
     * @param predictions a list holding exactly one prediction array
     * @return the mean of the element-wise loss
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray prediction = predictions.singletonOrThrow();
        NDArray label = labels.singletonOrThrow().reshape(prediction.getShape());

        NDArray loss;
        if (fromSigmoid) {
            NDArray positiveTerm = epsLog(prediction).mul(label);
            NDArray negativeTerm =
                    epsLog(NDArrays.sub(1.0d, prediction)).mul(NDArrays.sub(1.0d, label));
            // Cross-entropy is the NEGATIVE log-likelihood. Without the negation this branch
            // yields a non-positive value whose sign is opposite to the logits branch below.
            loss = positiveTerm.add(negativeTerm).neg();
        } else {
            // Formulation on raw scores that avoids evaluating exp() of a large positive value.
            loss =
                    Activation.relu(prediction)
                            .sub(prediction.mul(label))
                            .add(Activation.softPlus(prediction.abs().neg()));
        }

        if (weight != 1.0f) {
            loss = loss.mul(weight);
        }
        return loss.mean();
    }

    /**
     * Returns {@code log(array + EPSILON)} element-wise.
     *
     * @param array the values to take the logarithm of
     * @return the element-wise logarithm of the shifted values
     */
    private NDArray epsLog(NDArray array) {
        return array.add(EPSILON).log();
    }
}
