package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/loss/QuantileL1Loss.class */
/**
 * {@code QuantileL1Loss} computes a quantile (pinball-style) L1 loss between labels and
 * predictions.
 *
 * <p>For a label {@code y}, a prediction {@code p} and the configured quantile {@code q}, each
 * element contributes {@code 2 * |(p - y) * (1[y <= p] - q)|}. For {@code q} in {@code [0, 1]}
 * this equals:
 * <ul>
 *   <li>{@code 2 * (1 - q) * (p - y)} when the prediction over-estimates the label;</li>
 *   <li>{@code 2 * q * (y - p)} when it under-estimates the label.</li>
 * </ul>
 * The elementwise values are then averaged with {@code mean()}.
 */
public class QuantileL1Loss extends Loss {

    /** Quantile position {@code q} that weights under- versus over-estimation. */
    private final Number quantile;

    /**
     * Creates a quantile L1 loss named {@code "QuantileL1Loss"}.
     *
     * @param quantile the quantile position {@code q}
     */
    public QuantileL1Loss(float quantile) {
        this("QuantileL1Loss", quantile);
    }

    /**
     * Creates a quantile L1 loss with a custom name.
     *
     * @param name the name passed to the {@link Loss} super constructor
     * @param quantile the quantile position {@code q}
     */
    public QuantileL1Loss(String name, float quantile) {
        super(name);
        this.quantile = quantile;
    }

    /**
     * Computes the mean quantile L1 loss.
     *
     * @param labels list expected to hold a single label array (retrieved via
     *     {@code singletonOrThrow}); it is reshaped to the prediction's shape
     * @param predictions list expected to hold a single prediction array (retrieved via
     *     {@code singletonOrThrow})
     * @return the mean of the elementwise quantile loss values
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray pred = predictions.singletonOrThrow();
        NDArray label = labels.singletonOrThrow().reshape(pred.getShape());
        // The boolean mask 1[label <= pred] is cast to FLOAT32 so that the quantile can be
        // subtracted from it. This yields (1 - q) where pred over-estimates and (-q) elsewhere.
        NDArray overEstimate = label.lte(pred).toType(DataType.FLOAT32, false);
        // abs() makes both branches non-negative. The factor of 2 is applied before averaging.
        NDArray loss = pred.sub(label).mul(overEstimate.sub(quantile)).abs().mul(2);
        return loss.mean();
    }
}
