package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/modality/nlp/generate/StepGeneration.class */
/**
 * Static helpers that each compute a single generation step (contrastive, greedy or beam) from
 * model outputs held in {@link NDArray}s.
 *
 * <p>Sanity checks are expressed as {@code assert} statements and are only evaluated when
 * assertions are enabled ({@code -ea}).
 */
public final class StepGeneration {

    private StepGeneration() {}

    /**
     * Performs one contrastive-search step and selects one candidate token per batch row.
     *
     * <p>Every candidate receives the score {@code (1 - alpha) * p - alpha * maxSim}, where:
     *
     * <ul>
     *   <li>{@code p} is {@code softmax(logits, axis 1)} gathered at {@code topKIds} along axis 1.
     *   <li>{@code maxSim} is the largest cosine similarity between the candidate's hidden state
     *       and any (unmasked) context hidden state.
     * </ul>
     *
     * <p>The arg-max of the score along axis 1 is selected. The method name keeps its historical
     * spelling for compatibility with existing callers.
     *
     * @param topKIds candidate token ids. Its dimensions 0 and 1 are read as (batch, topK), so it
     *     is presumably {@code [batch, topK]} — TODO confirm
     * @param logits model scores. The softmax is taken along axis 1 and then gathered with
     *     {@code topKIds} along axis 1, so it is presumably {@code [batch, vocab]}
     * @param contextHiddenStates hidden states of the context. It is transposed with
     *     {@code (0, 2, 1)}, so it must be rank 3; presumably {@code [batch, pastSeq, dim]}
     * @param topkHiddenStates hidden states of the candidates. It is reshaped to
     *     {@code [batch, topK, dim]}, where {@code dim} is its own last dimension
     * @param offsets one value per batch row. Context positions {@code [0, offset)} of that row
     *     are excluded from the similarity maximum (presumably padding — confirm with caller)
     * @param alpha weight of the similarity penalty; {@code 1 - alpha} weights the model
     *     probability
     * @return an {@link NDList} containing two arrays:
     *     <ol>
     *       <li>the selected ids from {@code topKIds}, reshaped to {@code (-1, 1)};
     *       <li>the selected candidate index per batch row (arg-max along axis 1).
     *     </ol>
     */
    public static NDList constrastiveStepGeneration(
            NDArray topKIds,
            NDArray logits,
            NDArray contextHiddenStates,
            NDArray topkHiddenStates,
            NDArray offsets,
            float alpha) {
        long batch = topKIds.getShape().get(0);
        long topK = topKIds.getShape().get(1);
        long hiddenDim = topkHiddenStates.getShape().getLastDimension();

        // L2-normalise along the hidden axis so that the batched dot product below is a cosine
        // similarity: [batch, topK, dim] x [batch, dim, pastSeq] -> [batch, topK, pastSeq].
        NDArray candidates = topkHiddenStates.reshape(batch, topK, hiddenDim).normalize(2.0d, 2L);
        NDArray context = contextHiddenStates.normalize(2.0d, 2L).transpose(0, 2, 1);
        NDArray cosSimilarity = candidates.batchMatMul(context);

        // Overwrite positions [0, offset) of each row with -1. This is the smallest possible
        // cosine similarity, so these positions can never raise the row's maximum.
        long[] offsetArray = offsets.toLongArray();
        for (int i = 0; i < offsetArray.length; i++) {
            cosSimilarity.set(new NDIndex("{}, :, {}:{}", i, 0, offsetArray[i]), -1);
        }

        // Reduce over the context axis: the highest similarity for every candidate.
        NDArray maxSimilarity = cosSimilarity.max(new int[] {2});
        assert maxSimilarity.getShape().getShape().length == 2 : "Wrong output size";

        NDArray confidence = logits.softmax(1).gather(topKIds, 1);
        NDArray score = confidence.muli(1.0f - alpha).subi(maxSimilarity.muli(alpha));
        NDArray select = score.argMax(1);

        // Pair every batch row with its selected candidate column.
        NDIndex selectIndex = new NDIndex("{}, {}, ...", batchRange(logits, batch), select);
        NDArray outputIds = topKIds.get(selectIndex).reshape(-1, 1);
        return new NDList(outputIds, select);
    }

    /**
     * Performs one greedy step: takes the last position along axis 1 and returns the arg-max
     * over the last axis.
     *
     * @param logits a rank-3 array (asserted). Presumably {@code [batch, seq, vocab]}
     * @return the arg-max indices of the last position, with a new axis inserted at 1; the
     *     shape is {@code [logits.dim0, 1]}
     */
    public static NDArray greedyStepGen(NDArray logits) {
        assert logits.getShape().getShape().length == 3 : "unexpected input";
        return logits.get(":, -1, :").argMax(-1).expandDims(1);
    }

    /**
     * Performs one beam-search step. The step keeps the {@code numBeam} best extensions of each
     * batch row out of {@code numBeam * numBeam} candidate (source beam, child token) pairs.
     *
     * <p>Each source beam first keeps its top {@code numBeam} children. Each child's
     * probability is multiplied by the source beam's entry in {@code lastProbs}. The best
     * {@code numBeam} pairs are then selected over all {@code numBeam * numBeam} pairs.
     *
     * @param lastProbs per-beam probabilities, reshaped to {@code (numBatch, numBeam, 1)}.
     *     Presumably the {@code newProbs} returned by the previous step — confirm with caller
     * @param logits model scores. Its last position along axis 1 is softmax-ed along axis 1 and
     *     reshaped to {@code (numBatch, numBeam, -1)}. Presumably
     *     {@code [numBatch * numBeam, seq, vocab]}
     * @param numBatch number of batch rows
     * @param numBeam beam width. It must fit in an {@code int}; otherwise
     *     {@link Math#toIntExact} throws {@link ArithmeticException}
     * @return an {@link NDList} containing three arrays:
     *     <ol>
     *       <li>the selected token ids, shape {@code (numBatch, numBeam, 1)};
     *       <li>their chained probabilities, shape {@code (numBatch, numBeam)}, L1-normalised
     *           along axis 1;
     *       <li>for each selected beam, the index of the source beam it extends, shape
     *           {@code (numBatch, numBeam)}.
     *     </ol>
     */
    public static NDList beamStepGeneration(
            NDArray lastProbs, NDArray logits, long numBatch, long numBeam) {
        int beamWidth = Math.toIntExact(numBeam);
        long numCandidates = numBeam * numBeam;

        // Top numBeam children of every source beam: (numBatch, numBeam[source], numBeam[child]).
        NDList children =
                logits.get(":, -1, :")
                        .softmax(1)
                        .reshape(numBatch, numBeam, -1)
                        .topK(beamWidth, -1, true, false);
        NDArray childIds = children.get(1);
        NDArray chainedProbs = children.get(0).muli(lastProbs.reshape(numBatch, numBeam, 1));

        // Flat indices into the (source, child) pairs of each batch row: (numBatch, numBeam).
        NDArray select =
                chainedProbs
                        .reshape(numBatch, numCandidates)
                        .topK(beamWidth, -1, true, false)
                        .get(1);
        assert select.getDataType() == DataType.INT64 : "Wrong output! Expect integer division";
        assert select.getShape().getShape().length == 2 : "Wrong size. Expect [batch, beamNew]";

        // Pair every batch row with each of its numBeam selected flat indices.
        NDArray rowIndex = batchRange(logits, numBatch).expandDims(1).repeat(1, numBeam);
        NDIndex selectIndex = new NDIndex("{}, {}, ...", rowIndex, select);

        NDArray outputIds = childIds.reshape(numBatch, numCandidates).get(selectIndex).expandDims(2);
        NDArray newProbs =
                chainedProbs.reshape(numBatch, numCandidates).get(selectIndex).normalize(1.0d, 1L);

        long[] sourceBeams = toSourceBeamIndices(select.toLongArray(), numBeam);
        NDArray sourceBeamSelected =
                logits.getManager().create(sourceBeams, new Shape(numBatch, numBeam));
        return new NDList(outputIds, newProbs, sourceBeamSelected);
    }

    /** Returns {@code [0, 1, ..., count - 1]} as INT64, created by {@code owner}'s manager. */
    private static NDArray batchRange(NDArray owner, long count) {
        return owner.getManager().arange(0.0f, (float) count, 1.0f, DataType.INT64);
    }

    /**
     * Converts flat row-major indices over {@code (numBeam source, numBeam child)} pairs into
     * source-beam indices: {@code floorDiv(index, numBeam)}.
     */
    private static long[] toSourceBeamIndices(long[] flatIndices, long numBeam) {
        long[] sourceBeams = new long[flatIndices.length];
        for (int i = 0; i < flatIndices.length; i++) {
            sourceBeams[i] = Math.floorDiv(flatIndices[i], numBeam);
        }
        return sourceBeams;
    }
}
