package ai.djl.training.loss;

import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/loss/Loss.class */
/**
 * Abstract base class for loss functions.
 *
 * <p>Besides providing static factory methods for the concrete losses, a {@code Loss} is also an
 * {@link Evaluator}. For every accumulator key it keeps two values:
 *
 * <ul>
 *   <li>a running sum of the loss value;
 *   <li>a count of how many times {@link #updateAccumulators} has been called for that key.
 * </ul>
 *
 * <p>{@link #getAccumulator} reports the sum divided by the count.
 */
public abstract class Loss extends Evaluator {

    /** Running sum, per accumulator key, of the summed loss value of each update. */
    private final Map<String, Float> totalLoss = new ConcurrentHashMap<>();

    /**
     * Creates a loss with the given name.
     *
     * @param name the name of this loss, forwarded to the {@link Evaluator} constructor
     */
    public Loss(String name) {
        super(name);
    }

    /**
     * Creates an {@link L1Loss} using its no-argument constructor.
     *
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss() {
        return new L1Loss();
    }

    /**
     * Creates an {@link L1Loss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String name) {
        return new L1Loss(name);
    }

    /**
     * Creates an {@link L1Loss} with the given name and weight.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L1Loss} constructor
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String name, float weight) {
        return new L1Loss(name, weight);
    }

    /**
     * Creates a {@link QuantileL1Loss} for the given quantile.
     *
     * @param quantile the quantile forwarded to the {@link QuantileL1Loss} constructor
     * @return a new {@link QuantileL1Loss}
     */
    public static QuantileL1Loss quantileL1Loss(float quantile) {
        return new QuantileL1Loss(quantile);
    }

    /**
     * Creates a named {@link QuantileL1Loss} for the given quantile.
     *
     * @param name the name of the loss
     * @param quantile the quantile forwarded to the {@link QuantileL1Loss} constructor
     * @return a new {@link QuantileL1Loss}
     */
    public static QuantileL1Loss quantileL1Loss(String name, float quantile) {
        return new QuantileL1Loss(name, quantile);
    }

    /**
     * Creates an {@link L2Loss} using its no-argument constructor.
     *
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss() {
        return new L2Loss();
    }

    /**
     * Creates an {@link L2Loss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String name) {
        return new L2Loss(name);
    }

    /**
     * Creates an {@link L2Loss} with the given name and weight.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L2Loss} constructor
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String name, float weight) {
        return new L2Loss(name, weight);
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() {
        return new SigmoidBinaryCrossEntropyLoss();
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String name) {
        return new SigmoidBinaryCrossEntropyLoss(name);
    }

    /**
     * Creates a fully configured {@link SigmoidBinaryCrossEntropyLoss}.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link SigmoidBinaryCrossEntropyLoss} constructor
     * @param fromSigmoid the boolean flag forwarded to the constructor
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(
            String name, float weight, boolean fromSigmoid) {
        return new SigmoidBinaryCrossEntropyLoss(name, weight, fromSigmoid);
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() {
        return new SoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String name) {
        return new SoftmaxCrossEntropyLoss(name);
    }

    /**
     * Creates a fully configured {@link SoftmaxCrossEntropyLoss}.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link SoftmaxCrossEntropyLoss} constructor
     * @param classAxis the axis argument forwarded to the constructor
     * @param sparseLabel the first boolean flag forwarded to the constructor
     * @param fromLogit the second boolean flag forwarded to the constructor
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(
            String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        return new SoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit);
    }

    /**
     * Creates a {@link MaskedSoftmaxCrossEntropyLoss} using its no-argument constructor.
     *
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss() {
        return new MaskedSoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link MaskedSoftmaxCrossEntropyLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(String name) {
        return new MaskedSoftmaxCrossEntropyLoss(name);
    }

    /**
     * Creates a fully configured {@link MaskedSoftmaxCrossEntropyLoss}.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link MaskedSoftmaxCrossEntropyLoss} constructor
     * @param classAxis the axis argument forwarded to the constructor
     * @param sparseLabel the first boolean flag forwarded to the constructor
     * @param fromLogit the second boolean flag forwarded to the constructor
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(
            String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        return new MaskedSoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit);
    }

    /**
     * Creates a {@link HingeLoss} using its no-argument constructor.
     *
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss() {
        return new HingeLoss();
    }

    /**
     * Creates a {@link HingeLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String name) {
        return new HingeLoss(name);
    }

    /**
     * Creates a fully configured {@link HingeLoss}.
     *
     * @param name the name of the loss
     * @param margin the integer argument forwarded to the {@link HingeLoss} constructor
     * @param weight the weight forwarded to the {@link HingeLoss} constructor
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String name, int margin, float weight) {
        return new HingeLoss(name, margin, weight);
    }

    /**
     * Creates an {@link L1WeightDecay} over the given parameters.
     *
     * @param parameters the parameters forwarded to the {@link L1WeightDecay} constructor
     * @return a new {@link L1WeightDecay}
     */
    public static L1WeightDecay l1WeightedDecay(NDList parameters) {
        return new L1WeightDecay(parameters);
    }

    /**
     * Creates a named {@link L1WeightDecay} over the given parameters.
     *
     * @param name the name of the loss
     * @param parameters the parameters forwarded to the {@link L1WeightDecay} constructor
     * @return a new {@link L1WeightDecay}
     */
    public static L1WeightDecay l1WeightedDecay(String name, NDList parameters) {
        return new L1WeightDecay(name, parameters);
    }

    /**
     * Creates a named, weighted {@link L1WeightDecay}.
     *
     * <p>Note that the weight comes before the parameters here, but after them in the constructor.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L1WeightDecay} constructor
     * @param parameters the parameters forwarded to the {@link L1WeightDecay} constructor
     * @return a new {@link L1WeightDecay}
     */
    public static L1WeightDecay l1WeightedDecay(String name, float weight, NDList parameters) {
        return new L1WeightDecay(name, parameters, weight);
    }

    /**
     * Creates an {@link L2WeightDecay} over the given parameters.
     *
     * @param parameters the parameters forwarded to the {@link L2WeightDecay} constructor
     * @return a new {@link L2WeightDecay}
     */
    public static L2WeightDecay l2WeightedDecay(NDList parameters) {
        return new L2WeightDecay(parameters);
    }

    /**
     * Creates a named {@link L2WeightDecay} over the given parameters.
     *
     * @param name the name of the loss
     * @param parameters the parameters forwarded to the {@link L2WeightDecay} constructor
     * @return a new {@link L2WeightDecay}
     */
    public static L2WeightDecay l2WeightedDecay(String name, NDList parameters) {
        return new L2WeightDecay(name, parameters);
    }

    /**
     * Creates a named, weighted {@link L2WeightDecay}.
     *
     * <p>Note that the weight comes before the parameters here, but after them in the constructor.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L2WeightDecay} constructor
     * @param parameters the parameters forwarded to the {@link L2WeightDecay} constructor
     * @return a new {@link L2WeightDecay}
     */
    public static L2WeightDecay l2WeightedDecay(String name, float weight, NDList parameters) {
        return new L2WeightDecay(name, parameters, weight);
    }

    /**
     * Creates an {@link ElasticNetWeightDecay} over the given parameters.
     *
     * @param parameters the parameters forwarded to the {@link ElasticNetWeightDecay} constructor
     * @return a new {@link ElasticNetWeightDecay}
     */
    public static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters) {
        return new ElasticNetWeightDecay(parameters);
    }

    /**
     * Creates a named {@link ElasticNetWeightDecay} over the given parameters.
     *
     * @param name the name of the loss
     * @param parameters the parameters forwarded to the {@link ElasticNetWeightDecay} constructor
     * @return a new {@link ElasticNetWeightDecay}
     */
    public static ElasticNetWeightDecay elasticNetWeightedDecay(String name, NDList parameters) {
        return new ElasticNetWeightDecay(name, parameters);
    }

    /**
     * Creates a named {@link ElasticNetWeightDecay} with a single weight.
     *
     * <p>Note that the weight comes before the parameters here, but after them in the constructor.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link ElasticNetWeightDecay} constructor
     * @param parameters the parameters forwarded to the {@link ElasticNetWeightDecay} constructor
     * @return a new {@link ElasticNetWeightDecay}
     */
    public static ElasticNetWeightDecay elasticNetWeightedDecay(
            String name, float weight, NDList parameters) {
        return new ElasticNetWeightDecay(name, parameters, weight);
    }

    /**
     * Creates a named {@link ElasticNetWeightDecay} with two weights.
     *
     * <p>Note that the weights come before the parameters here, but after them in the
     * constructor.
     *
     * @param name the name of the loss
     * @param weight1 the first weight forwarded to the {@link ElasticNetWeightDecay} constructor
     * @param weight2 the second weight forwarded to the {@link ElasticNetWeightDecay} constructor
     * @param parameters the parameters forwarded to the {@link ElasticNetWeightDecay} constructor
     * @return a new {@link ElasticNetWeightDecay}
     */
    public static ElasticNetWeightDecay elasticNetWeightedDecay(
            String name, float weight1, float weight2, NDList parameters) {
        return new ElasticNetWeightDecay(name, parameters, weight1, weight2);
    }

    /**
     * Registers {@code key} with a zero update count and a zero loss sum.
     *
     * @param key the accumulator key to add
     */
    @Override
    public void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /**
     * Updates a single accumulator; equivalent to {@code updateAccumulators(new String[] {key},
     * labels, predictions)}.
     *
     * @param key the accumulator key to update
     * @param labels the first argument passed to {@code evaluate}
     * @param predictions the second argument passed to {@code evaluate}
     */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        updateAccumulators(new String[] {key}, labels, predictions);
    }

    /**
     * Evaluates the loss once and adds the result to every accumulator in {@code keys}.
     *
     * <p>The result of {@code evaluate} is reduced with {@code sum()} to a single float. That
     * float is added to each key's loss sum. Each key's update count goes up by exactly one per
     * call, independent of how many elements were summed.
     *
     * @param keys the accumulator keys to update; each must have been registered with {@link
     *     #addAccumulator}
     * @param labels the first argument passed to {@code evaluate}
     * @param predictions the second argument passed to {@code evaluate}
     * @throws NullPointerException if a key was never registered, because the stored count is
     *     unboxed
     */
    @Override
    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        // Evaluate and reduce once, then reuse the scalar for every key.
        float update = evaluate(labels, predictions).sum().getFloat();
        for (String key : keys) {
            totalInstances.compute(key, (k, count) -> count + 1);
            totalLoss.compute(key, (k, sum) -> sum + update);
        }
    }

    /**
     * Resets {@code key} to a zero update count and a zero loss sum, inserting it if absent.
     *
     * @param key the accumulator key to reset
     */
    @Override
    public void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /**
     * Returns the accumulated loss sum divided by the number of updates for {@code key}.
     *
     * @param key the accumulator key to read
     * @return the average loss per update, or {@link Float#NaN} if there have been no updates
     *     since the key was added or last reset
     * @throws IllegalArgumentException if {@code key} has no accumulator
     */
    @Override
    public float getAccumulator(String key) {
        // Read the count exactly once. Re-reading it after the zero check could observe a
        // concurrent reset and divide by zero.
        Long count = totalInstances.get(key);
        if (count == null) {
            throw new IllegalArgumentException("No loss found at that path");
        }
        if (count == 0L) {
            return Float.NaN;
        }
        Float sum = totalLoss.get(key);
        if (sum == null) {
            // Count is present but no loss sum: treat it like any other unknown key instead of
            // failing with an unboxing NullPointerException.
            throw new IllegalArgumentException("No loss found at that path");
        }
        return sum / count.floatValue();
    }
}
