package ai.djl.nn.transformer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Collections;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/nn/transformer/IdEmbedding.class */
/**
 * A block that maps integer ids to dense vectors by looking them up in a trainable embedding table
 * of shape {@code (dictionarySize, embeddingSize)}.
 *
 * <p>Instances are created through {@link Builder}. The forward pass takes a single array of ids
 * and returns an array whose shape is the input shape with the embedding dimension appended.
 */
public final class IdEmbedding extends AbstractBlock {
    private static final String EMBEDDING_PARAM_NAME = "embedding";

    /** Number of rows in the embedding table, i.e. the number of distinct ids. */
    private final int dictionarySize;
    /** Length of each embedding vector, i.e. the number of columns in the table. */
    private final int embeddingSize;
    /** The embedding table, registered as a WEIGHT parameter named {@code "embedding"}. */
    private final Parameter embedding;

    /** Builder for {@link IdEmbedding}; both sizes must be set to positive values. */
    public static final class Builder {
        private int dictionarySize;
        private int embeddingSize;

        /**
         * Sets the number of ids the embedding table can represent.
         *
         * @param i the dictionary size; must be positive
         * @return this builder
         */
        public Builder setDictionarySize(int i) {
            this.dictionarySize = i;
            return this;
        }

        /**
         * Sets the length of each embedding vector.
         *
         * @param i the embedding size; must be positive
         * @return this builder
         */
        public Builder setEmbeddingSize(int i) {
            this.embeddingSize = i;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link IdEmbedding}
         * @throws IllegalArgumentException if the dictionary size or the embedding size is not
         *     positive
         */
        public IdEmbedding build() {
            if (this.dictionarySize <= 0) {
                throw new IllegalArgumentException("You must specify the dictionary Size for the embedding.");
            }
            // Reject negative values as well as the unset default of 0. Previously only 0 was
            // rejected, so a negative size slipped through to the parameter Shape construction.
            if (this.embeddingSize <= 0) {
                throw new IllegalArgumentException("You must specify the embedding size");
            }
            return new IdEmbedding(this);
        }
    }

    private IdEmbedding(Builder builder) {
        this.dictionarySize = builder.dictionarySize;
        this.embeddingSize = builder.embeddingSize;
        this.embedding =
                addParameter(
                        Parameter.builder()
                                .setName(EMBEDDING_PARAM_NAME)
                                .setType(Parameter.Type.WEIGHT)
                                .optShape(new Shape(this.dictionarySize, this.embeddingSize))
                                .build());
    }

    /**
     * Returns the output shape: the input shape with the embedding dimension appended.
     *
     * @param shapeArr input shapes; only the first is used
     * @return a single-element array holding the output shape
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        return new Shape[] {shapeArr[0].addAll(new Shape(this.embeddingSize))};
    }

    /**
     * Looks up one row of the embedding table for every id in the single input array.
     *
     * @param parameterStore store used to fetch the embedding table
     * @param nDList a list holding exactly one array of ids
     * @param z training flag, forwarded to {@link ParameterStore#getValue}
     * @param pairList unused extra parameters
     * @return a list holding the embeddings, shaped as the input shape plus the embedding dimension
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        NDArray input = nDList.singletonOrThrow();
        // The sub-manager scopes the temporarily attached table. Try-with-resources closes it on
        // every exit path.
        try (NDManager scope = NDManager.subManagerOf(input)) {
            // Lay all ids out as a single (1, N) index row so gatherNd selects one table row per
            // id (MXNet-style gather_nd index layout).
            NDArray ids = input.flatten().reshape(1, input.getShape().size());
            NDArray table = parameterStore.getValue(this.embedding, ids.getDevice(), z);
            scope.tempAttachAll(table);
            NDArray gathered = table.gatherNd(ids);
            // Move the result to the caller's manager so it is not released with the scope.
            gathered.attach(nDList.getManager());
            Shape outputShape = input.getShape().addAll(new Shape(table.getShape().get(1)));
            return new NDList(gathered.reshape(outputShape));
        }
    }

    /**
     * Scores vectors against every row of the embedding table.
     *
     * <p>The input is reshaped to {@code (-1, embeddingSize)} and multiplied by the transposed
     * embedding table. The scores are then normalized with {@code logSoftmax} over the dictionary
     * axis. Note that the values returned are log-probabilities, despite the method name.
     *
     * @param parameterStore store used to fetch the embedding table
     * @param nDArray vectors whose last dimension must equal the embedding size
     * @param z training flag, forwarded to {@link ParameterStore#getValue}
     * @return log-probabilities shaped as {@code nDArray}'s shape without its last dimension, plus
     *     a trailing dimension of size {@code dictionarySize}
     */
    public NDArray probabilities(ParameterStore parameterStore, NDArray nDArray, boolean z) {
        NDArray flat = nDArray.reshape(-1, this.embeddingSize);
        NDArray tableT = parameterStore.getValue(this.embedding, nDArray.getDevice(), z).transpose();
        // Hand the transposed table to the input's manager (presumably so it is released along
        // with the input).
        tableT.attach(nDArray.getManager());
        Shape inputShape = nDArray.getShape();
        Shape outputShape =
                inputShape.slice(0, inputShape.dimension() - 1).addAll(new Shape(this.dictionarySize));
        return flat.dot(tableT).logSoftmax(1).reshape(outputShape);
    }

    /**
     * Fetches the embedding table from the parameter store.
     *
     * @param parameterStore store used to fetch the table
     * @param device device on which the table is requested
     * @param z training flag, forwarded to {@link ParameterStore#getValue}
     * @return the embedding table
     */
    public NDArray getValue(ParameterStore parameterStore, Device device, boolean z) {
        return parameterStore.getValue(this.embedding, device, z);
    }

    /** Records the block's single input name; this block has no child blocks to initialize. */
    @Override // ai.djl.nn.AbstractBaseBlock
    public void initializeChildBlocks(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        this.inputNames = Collections.singletonList("tokenIds");
    }
}
