package dev.langchain4j.model.embedding.onnx;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/* loaded from: input_file:lib/langchain4j-embeddings-1.1.0-beta7.jar:dev/langchain4j/model/embedding/onnx/AbstractInProcessEmbeddingModel.class */
/**
 * Base class for embedding models that run an {@link OnnxBertBiEncoder} inside the current JVM.
 *
 * <p>A request with a single segment is embedded on the calling thread. A request with several
 * segments is fanned out as one task per segment onto {@link #executor}. The results are then
 * collected in input order.
 */
public abstract class AbstractInProcessEmbeddingModel extends DimensionAwareEmbeddingModel {

    /**
     * Number of tokens subtracted from each encoder token count before it is reported as
     * {@link TokenUsage}. NOTE(review): presumably the two special tokens (e.g. [CLS]/[SEP]) the
     * BERT tokenizer adds around every input — confirm against OnnxBertBiEncoder.
     */
    private static final int SPECIAL_TOKEN_COUNT = 2;

    /** Executor on which segments are embedded when more than one segment is requested. */
    private final Executor executor;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the model.
     *
     * @param executor executor used to embed multiple segments in parallel; when {@code null}, a
     *     default pool with one thread per available processor is created instead
     */
    public AbstractInProcessEmbeddingModel(Executor executor) {
        this.executor = Utils.getOrDefault(executor, (Supplier<Executor>) this::createDefaultExecutor);
    }

    /**
     * Creates the fallback executor: a fixed-size pool with one thread per available processor
     * and an unbounded work queue. Core threads are allowed to time out after one second of
     * inactivity, so an idle model does not keep pool threads alive.
     */
    private Executor createDefaultExecutor() {
        int threads = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                threads, threads, 1L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
        pool.allowCoreThreadTimeOut(true);
        return pool;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates an encoder from two resources resolved through the current thread's context class
     * loader.
     *
     * @param str resource name of the first encoder input (presumably the ONNX model file — verify)
     * @param str2 resource name of the second encoder input (presumably the tokenizer file — verify)
     * @param poolingMode pooling mode passed through to the encoder
     * @return the newly constructed encoder
     */
    public static OnnxBertBiEncoder loadFromJar(String str, String str2, PoolingMode poolingMode) {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        // NOTE(review): getResourceAsStream returns null for a missing resource, and neither stream
        // is closed here — confirm that OnnxBertBiEncoder rejects null and closes both streams.
        return new OnnxBertBiEncoder(
                classLoader.getResourceAsStream(str), classLoader.getResourceAsStream(str2), poolingMode);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates an encoder from two files on the local file system.
     *
     * @param path first encoder input (presumably the ONNX model file — verify)
     * @param path2 second encoder input (presumably the tokenizer file — verify)
     * @param poolingMode pooling mode passed through to the encoder
     * @return the newly constructed encoder
     */
    public static OnnxBertBiEncoder loadFromFileSystem(Path path, Path path2, PoolingMode poolingMode) {
        return new OnnxBertBiEncoder(path, path2, poolingMode);
    }

    /** Returns the encoder that performs the actual embedding. */
    protected abstract OnnxBertBiEncoder model();

    /**
     * Embeds every segment and reports the summed token usage.
     *
     * @param list segments to embed; must not be empty
     * @return one embedding per segment, in input order, plus the total token usage
     */
    @Override // dev.langchain4j.model.embedding.EmbeddingModel
    public Response<List<Embedding>> embedAll(List<TextSegment> list) {
        ValidationUtils.ensureNotEmpty(list, "segments");
        // A single segment gains nothing from a thread hop, so it is embedded inline.
        return list.size() == 1 ? embedInTheSameThread(list.get(0)) : parallelizeEmbedding(list);
    }

    /** Embeds one segment on the calling thread. */
    private Response<List<Embedding>> embedInTheSameThread(TextSegment textSegment) {
        OnnxBertBiEncoder.EmbeddingAndTokenCount result = model().embed(textSegment.text());
        return Response.from(
                Collections.singletonList(Embedding.from(result.embedding)),
                new TokenUsage(Integer.valueOf(result.tokenCount - SPECIAL_TOKEN_COUNT)));
    }

    /**
     * Submits one embedding task per segment to {@link #executor} and collects the results in
     * input order.
     *
     * <p>If any task fails or the calling thread is interrupted while waiting, every pending task
     * is cancelled so that segments still queued on the executor are not computed for nothing.
     * The failure is then rethrown.
     */
    private Response<List<Embedding>> parallelizeEmbedding(List<TextSegment> segments) {
        List<CompletableFuture<OnnxBertBiEncoder.EmbeddingAndTokenCount>> futures = segments.stream()
                .map(segment -> CompletableFuture.supplyAsync(() -> model().embed(segment.text()), executor))
                .collect(Collectors.toList());

        List<Embedding> embeddings = new ArrayList<>(futures.size());
        int tokenCount = 0;
        try {
            for (CompletableFuture<OnnxBertBiEncoder.EmbeddingAndTokenCount> future : futures) {
                OnnxBertBiEncoder.EmbeddingAndTokenCount result = await(future);
                embeddings.add(Embedding.from(result.embedding));
                tokenCount += result.tokenCount - SPECIAL_TOKEN_COUNT;
            }
        } catch (RuntimeException e) {
            // cancel(false) is a no-op for futures that have already completed.
            futures.forEach(f -> f.cancel(false));
            throw e;
        }
        return Response.from(embeddings, new TokenUsage(Integer.valueOf(tokenCount)));
    }

    /**
     * Blocks until {@code future} completes and returns its value.
     *
     * @throws RuntimeException wrapping the {@link ExecutionException} if the task failed, or
     *     the {@link InterruptedException} if the wait was interrupted. In the interrupted case the
     *     thread's interrupt flag is restored before throwing.
     */
    private static <T> T await(CompletableFuture<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }
}
