package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.Optional;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/nn/core/ConstantEmbedding.class */
/**
 * An embedding block that maps every input item to the same constant {@link NDArray}.
 *
 * <p>The constant is supplied at construction time and the block's parameters are frozen, so
 * training does not change it. Parameter save/load are no-ops, and the indexed-embedding hooks
 * ({@link #embed(Object)}, {@link #unembed(long)}, {@link #encode(Object)},
 * {@link #decode(byte[])}) return fixed placeholder values.
 */
public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedEmbedding {

    /** The constant value returned for every input item. */
    protected NDArray embedding;

    /**
     * Constructs a constant embedding.
     *
     * @param nDArray the value returned for every embedded item
     */
    public ConstantEmbedding(NDArray nDArray) {
        this.embedding = nDArray;
        freezeParameters(true);
    }

    /**
     * Copies the constant embedding into {@code manager} and flattens it to shape
     * {@code (1, embedding.size())}.
     *
     * <p>The leading axis of length 1 matters. {@link NDArray#repeat(long, long)} repeats each
     * <em>element</em> along the given axis. Repeating a single-row array along axis 0 therefore
     * yields whole copies of the embedding, rather than interleaving its individual values.
     *
     * @param manager the manager that will own the copy
     * @return a fresh single-row copy of the embedding
     */
    private NDArray copyAsSingleRow(NDManager manager) {
        NDArray copy = manager.create(this.embedding.getShape());
        this.embedding.copyTo(copy);
        return copy.reshape(1, this.embedding.size());
    }

    /**
     * Produces one copy of the constant embedding per element of the first input.
     *
     * @return a single-element list whose array has shape
     *     {@code inputShape ++ embedding.getShape()}
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    protected NDList forwardInternal(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        NDArray input = nDList.get(0);
        Shape outputShape = input.getShape().addAll(this.embedding.getShape());
        NDArray rows = copyAsSingleRow(input.getManager()).repeat(0, input.size());
        return new NDList(rows.reshape(outputShape));
    }

    /**
     * Returns the input shape with the embedding's shape appended.
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        return new Shape[]{shapeArr[0].addAll(this.embedding.getShape())};
    }

    /** No-op: the constant embedding is not persisted as a parameter. */
    @Override // ai.djl.nn.AbstractBaseBlock, ai.djl.nn.Block
    public void saveParameters(DataOutputStream dataOutputStream) {
    }

    /** No-op: nothing is read back, since {@link #saveParameters} writes nothing. */
    @Override // ai.djl.nn.AbstractBaseBlock, ai.djl.nn.Block
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream) {
    }

    /** Always empty: a constant embedding cannot be mapped back to an item. */
    @Override // ai.djl.nn.core.AbstractIndexedEmbedding
    public Optional<?> unembed(long j) {
        return Optional.empty();
    }

    /** Always returns an empty byte array. */
    @Override // ai.djl.nn.core.AbstractIndexedEmbedding
    public byte[] encode(Object obj) {
        return new byte[0];
    }

    /** Always returns {@code null}. */
    @Override // ai.djl.nn.core.AbstractIndexedEmbedding
    public Object decode(byte[] bArr) {
        return null;
    }

    /** Every item maps to index {@code 0}. */
    @Override // ai.djl.nn.core.AbstractIndexedEmbedding
    public long embed(Object obj) {
        return 0L;
    }

    /**
     * Embeds a batch of items, producing one copy of the constant embedding per item.
     *
     * @param nDManager the manager that will own the result
     * @param objArr the items to embed; only their count is used
     * @return an array of shape {@code (objArr.length) ++ embedding.getShape()}
     */
    @Override // ai.djl.nn.core.AbstractEmbedding
    public NDArray embed(NDManager nDManager, Object[] objArr) {
        // Repeat a single flattened row, not the raw embedding. Element-wise repeat on a
        // multi-element axis 0 would scramble the values across rows.
        Shape outputShape = new Shape(objArr.length).addAll(this.embedding.getShape());
        return copyAsSingleRow(nDManager).repeat(0, objArr.length).reshape(outputShape);
    }

    /** Every item is considered present, since all items share the same embedding. */
    @Override // ai.djl.nn.core.AbstractEmbedding
    public boolean hasItem(Object obj) {
        return true;
    }
}
