package ai.djl.training.evaluator;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.Pair;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/evaluator/SingleShotDetectionAccuracy.class */
/**
 * Accuracy evaluator that scores class predictions against targets produced by a
 * {@link MultiBoxTarget}.
 *
 * <p>The targets are not taken from the labels directly: the labels and predictions are first
 * passed through a default-configured {@link MultiBoxTarget}, and the third array of its output is
 * compared element-wise with the arg-max of the class predictions.
 */
public class SingleShotDetectionAccuracy extends AbstractAccuracy {

    /** Target generator built from the builder's defaults, created once per evaluator. */
    private final MultiBoxTarget multiBoxTarget = MultiBoxTarget.builder().build();

    /**
     * Creates the evaluator.
     *
     * @param name forwarded to the superclass constructor together with the constant {@code 0}
     *     (presumably the evaluator name and the class axis; confirm against
     *     {@code AbstractAccuracy})
     */
    public SingleShotDetectionAccuracy(String name) {
        super(name, 0);
    }

    /**
     * Counts how many generated targets equal the arg-max of the class predictions.
     *
     * @param labels list whose head array is handed to the target generator
     * @param predictions list whose element 0 and element 1 are used; element 0 is presumably the
     *     anchor boxes and element 1 the per-class predictions (verify against the producing
     *     model)
     * @return a pair of (number of elements in the target array, count of matching elements)
     */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);

        // The generator receives the class predictions with their last two axes swapped.
        NDArray swappedPredictions = classPredictions.transpose(0, 2, 1);
        NDList generatorInputs = new NDList(anchors, labels.head(), swappedPredictions);
        NDList generated = this.multiBoxTarget.target(generatorInputs);

        // Index 2 of the generator output is compared with the predictions below
        // (presumably the class targets; confirm against MultiBoxTarget).
        NDArray classTargets = generated.get(2);
        checkLabelShapes(classTargets, classPredictions);

        NDArray predictedClasses = classPredictions.argMax(-1);
        Long total = Long.valueOf(classTargets.size());

        // Cast the targets to INT64 before the equality test against the arg-max result.
        NDArray matches = classTargets.toType(DataType.INT64, false).eq(predictedClasses);
        return new Pair<>(total, matches.countNonzero());
    }
}
