package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import java.util.Iterator;
import java.util.Objects;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/loss/ElasticNetWeightDecay.class */
/**
 * Elastic-net weight-decay penalty computed over a fixed list of parameters.
 *
 * <p>{@link #evaluate(NDList, NDList)} returns
 * {@code lambda1 * sum_i sum(|w_i|) + lambda2 * sum_i sum(w_i^2)}, where {@code w_i} ranges over
 * the parameters supplied at construction. The labels and predictions passed to {@code evaluate}
 * are ignored; the penalty depends only on the parameters.
 */
public class ElasticNetWeightDecay extends Loss {

    /** Weight applied to the L1 term (sum of absolute values). */
    private final float lambda1;

    /** Weight applied to the L2 term (sum of squares). */
    private final float lambda2;

    /** Parameters the penalty is computed over. */
    private final NDList parameters;

    /**
     * Creates a penalty named {@code "ElasticNetWeightDecay"} with both weights set to {@code 1.0}.
     *
     * @param parameters the parameters to penalize; must not be {@code null}
     */
    public ElasticNetWeightDecay(NDList parameters) {
        this("ElasticNetWeightDecay", parameters);
    }

    /**
     * Creates a named penalty with both weights set to {@code 1.0}.
     *
     * @param name the display name of this loss
     * @param parameters the parameters to penalize; must not be {@code null}
     */
    public ElasticNetWeightDecay(String name, NDList parameters) {
        this(name, parameters, 1.0f);
    }

    /**
     * Creates a named penalty that uses the same weight for the L1 and L2 terms.
     *
     * @param name the display name of this loss
     * @param parameters the parameters to penalize; must not be {@code null}
     * @param lambda the weight applied to both the L1 and the L2 term
     */
    public ElasticNetWeightDecay(String name, NDList parameters, float lambda) {
        this(name, parameters, lambda, lambda);
    }

    /**
     * Creates a named penalty with independent L1 and L2 weights.
     *
     * @param name the display name of this loss
     * @param parameters the parameters to penalize; must not be {@code null}
     * @param lambda1 the weight applied to the L1 term
     * @param lambda2 the weight applied to the L2 term
     * @throws NullPointerException if {@code parameters} is {@code null}
     */
    public ElasticNetWeightDecay(String name, NDList parameters, float lambda1, float lambda2) {
        super(name);
        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
        // Fail at construction rather than with an unexplained NPE on the first evaluate().
        this.parameters = Objects.requireNonNull(parameters, "parameters");
    }

    /** Returns the sum of absolute values of {@code w}. */
    private NDArray l1(NDArray w) {
        return w.abs().sum();
    }

    /** Returns the sum of squares of {@code w}. */
    private NDArray l2(NDArray w) {
        return w.square().sum();
    }

    /**
     * Computes the weighted L1 + L2 penalty over the stored parameters.
     *
     * @param labels ignored
     * @param predictions ignored
     * @return {@code lambda1 * sumL1 + lambda2 * sumL2}, created on the manager of the parameter list
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // NOTE(review): getManager() is taken from the parameter list; presumably this requires
        // at least one parameter — confirm behavior for an empty list.
        NDManager manager = parameters.getManager();
        NDArray l1Sum = manager.create(0.0f);
        NDArray l2Sum = manager.create(0.0f);
        for (NDArray w : parameters) {
            l1Sum.addi(l1(w));
            l2Sum.addi(l2(w));
        }
        return l1Sum.muli(lambda1).addi(l2Sum.muli(lambda2));
    }
}
