package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.PairList;
import ai.djl.util.StringPair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:lib/tokenizers-0.31.1.jar:ai/djl/huggingface/translator/CrossEncoderTranslator.class */
/**
 * A {@link Translator} that scores text pairs with a cross-encoder model.
 *
 * <p>Each {@link StringPair} is tokenized as a (key, value) pair by the supplied {@link
 * HuggingFaceTokenizer}. The first output tensor of the model is returned as a {@code float[]},
 * optionally passed through an element-wise sigmoid.
 */
public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

    private final HuggingFaceTokenizer tokenizer;
    private final boolean includeTokenTypes;
    private final boolean sigmoid;
    private final Batchifier batchifier;

    /** Builder for {@link CrossEncoderTranslator}. */
    public static final class Builder {
        private HuggingFaceTokenizer tokenizer;
        private boolean includeTokenTypes;
        private boolean sigmoid = true;
        private Batchifier batchifier = Batchifier.STACK;

        Builder(HuggingFaceTokenizer huggingFaceTokenizer) {
            this.tokenizer = huggingFaceTokenizer;
        }

        /**
         * Sets whether token type ids are included in the model input.
         *
         * @param z {@code true} to include token type ids
         * @return this builder
         */
        public Builder optIncludeTokenTypes(boolean z) {
            this.includeTokenTypes = z;
            return this;
        }

        /**
         * Sets whether a sigmoid is applied to the model output (default {@code true}).
         *
         * @param z {@code true} to apply sigmoid
         * @return this builder
         */
        public Builder optSigmoid(boolean z) {
            this.sigmoid = z;
            return this;
        }

        /**
         * Sets the {@link Batchifier} used for batching (default {@link Batchifier#STACK}).
         *
         * @param batchifier the batchifier; may be {@code null} for no batching
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Configures the builder from a string-keyed argument map.
         *
         * <p>Recognized keys: {@code "includeTokenTypes"} (boolean), {@code "sigmoid"} (boolean,
         * default {@code true}), and {@code "batchifier"} (a name passed to {@link
         * Batchifier#fromString}, default {@code "stack"}).
         *
         * @param map the arguments
         */
        public void configure(Map<String, ?> map) {
            optIncludeTokenTypes(ArgumentsUtil.booleanValue(map, "includeTokenTypes"));
            optSigmoid(ArgumentsUtil.booleanValue(map, "sigmoid", true));
            optBatchifier(
                    Batchifier.fromString(ArgumentsUtil.stringValue(map, "batchifier", "stack")));
        }

        /**
         * Builds the translator.
         *
         * @return a new {@link CrossEncoderTranslator}
         * @throws IOException declared for API compatibility; not thrown by this implementation
         */
        public CrossEncoderTranslator build() throws IOException {
            return new CrossEncoderTranslator(
                    this.tokenizer, this.includeTokenTypes, this.sigmoid, this.batchifier);
        }
    }

    CrossEncoderTranslator(
            HuggingFaceTokenizer huggingFaceTokenizer, boolean z, boolean z2, Batchifier batchifier) {
        this.tokenizer = huggingFaceTokenizer;
        this.includeTokenTypes = z;
        this.sigmoid = z2;
        this.batchifier = batchifier;
    }

    /** {@inheritDoc} */
    @Override // ai.djl.translate.Translator
    public Batchifier getBatchifier() {
        return this.batchifier;
    }

    /**
     * Tokenizes a single text pair and converts it to model input.
     *
     * <p>The resulting {@link Encoding} is stored on the context under the attachment key {@code
     * "encoding"}.
     */
    @Override // ai.djl.translate.PreProcessor
    public NDList processInput(TranslatorContext translatorContext, StringPair stringPair) {
        Encoding encode = this.tokenizer.encode(stringPair.getKey(), stringPair.getValue());
        translatorContext.setAttachment("encoding", encode);
        return encode.toNDList(translatorContext.getNDManager(), this.includeTokenTypes);
    }

    /** Tokenizes a batch of text pairs and batchifies them with the configured batchifier. */
    @Override // ai.djl.translate.Translator
    public NDList batchProcessInput(TranslatorContext translatorContext, List<StringPair> list) {
        NDManager nDManager = translatorContext.getNDManager();
        Encoding[] batchEncode = this.tokenizer.batchEncode(new PairList<>(list));
        NDList[] nDListArr = new NDList[batchEncode.length];
        for (int i = 0; i < batchEncode.length; i++) {
            nDListArr[i] = batchEncode[i].toNDList(nDManager, this.includeTokenTypes);
        }
        return this.batchifier.batchify(nDListArr);
    }

    /** Returns the first output tensor as a float array, optionally after a sigmoid. */
    @Override // ai.djl.translate.PostProcessor
    public float[] processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDArray nDArray = nDList.get(0);
        if (this.sigmoid) {
            nDArray = nDArray.getNDArrayInternal().sigmoid();
        }
        return nDArray.toFloatArray();
    }

    /**
     * Splits the first batched output tensor into one {@code float[]} per item along axis 0.
     *
     * <p>An empty batch yields an empty list. When sigmoid is enabled it is applied to the whole
     * batch before splitting, which is equivalent to applying it per item because sigmoid is
     * element-wise.
     */
    @Override // ai.djl.translate.Translator
    public List<float[]> batchProcessOutput(TranslatorContext translatorContext, NDList nDList) {
        NDArray scores = nDList.get(0);
        int batchSize = Math.toIntExact(scores.size(0));
        if (batchSize == 0) {
            // Guard: avoids integer division by zero when splitting below.
            return Collections.emptyList();
        }
        if (this.sigmoid) {
            scores = scores.getNDArrayInternal().sigmoid();
        }
        float[] flat = scores.toFloatArray();
        if (batchSize == 1) {
            return Collections.singletonList(flat);
        }
        return splitRows(flat, batchSize);
    }

    /** Splits a row-major buffer into {@code rows} equally sized, independent arrays. */
    private static List<float[]> splitRows(float[] flat, int rows) {
        int rowLength = flat.length / rows;
        List<float[]> result = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++) {
            float[] row = new float[rowLength];
            System.arraycopy(flat, i * rowLength, row, 0, rowLength);
            result.add(row);
        }
        return result;
    }

    /**
     * Creates a builder for a translator that uses the given tokenizer.
     *
     * @param huggingFaceTokenizer the tokenizer
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer) {
        return new Builder(huggingFaceTokenizer);
    }

    /**
     * Creates a builder for the given tokenizer and configures it from an argument map.
     *
     * @param huggingFaceTokenizer the tokenizer
     * @param map the arguments; see {@link Builder#configure(Map)}
     * @return a configured builder
     */
    public static Builder builder(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, ?> map) {
        Builder builder = builder(huggingFaceTokenizer);
        builder.configure(map);
        return builder;
    }
}
