package dev.langchain4j.rag.content.aggregator;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.ContentMetadata;
import dev.langchain4j.rag.query.Query;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import kotlin.jvm.internal.IntCompanionObject;

/* loaded from: input_file:lib/langchain4j-core-1.1.0.jar:dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.class */
/**
 * A {@link ContentAggregator} that fuses the retrieved contents and then re-ranks them with a {@link ScoringModel}.
 * <p>
 * Aggregation steps:
 * <ol>
 *     <li>A single {@link Query} is chosen with the configured query selector.</li>
 *     <li>For each query, its lists of contents are fused with {@link ReciprocalRankFuser}.</li>
 *     <li>The per-query results are fused again into one list.</li>
 *     <li>The scoring model scores that list against the selected query's text.</li>
 *     <li>Contents scoring below {@code minScore} (if set) are dropped.</li>
 *     <li>The rest are sorted by descending score and truncated to {@code maxResults}.</li>
 * </ol>
 * Each returned {@link Content} carries its score under {@link ContentMetadata#RERANKED_SCORE}.
 */
public class ReRankingContentAggregator implements ContentAggregator {

    /**
     * Selects the only query present in the map.
     * <p>
     * Throws an {@link IllegalArgumentException} when the map holds more than one query, because it would
     * then be ambiguous which query to re-rank against. Supply a custom query selector in that case.
     */
    public static final Function<Map<Query, Collection<List<Content>>>, Query> DEFAULT_QUERY_SELECTOR = queryToContents -> {
        if (queryToContents.size() > 1) {
            throw Exceptions.illegalArgument("The 'queryToContents' contains %s queries, making the re-ranking ambiguous. Because there are multiple queries, it is unclear which one should be used for re-ranking. Please provide a 'querySelector' in the constructor/builder.", queryToContents.size());
        }
        return queryToContents.keySet().iterator().next();
    };

    private final ScoringModel scoringModel;
    private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
    private final Double minScore;
    private final Integer maxResults;

    /**
     * Builder for {@link ReRankingContentAggregator}. Unset values fall back to the constructor defaults:
     * {@link #DEFAULT_QUERY_SELECTOR}, no minimum score and {@link Integer#MAX_VALUE} results.
     */
    public static class ReRankingContentAggregatorBuilder {
        private ScoringModel scoringModel;
        private Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
        private Double minScore;
        private Integer maxResults;

        ReRankingContentAggregatorBuilder() {
        }

        /** Sets the (mandatory) model used to score contents against the selected query. */
        public ReRankingContentAggregatorBuilder scoringModel(ScoringModel scoringModel) {
            this.scoringModel = scoringModel;
            return this;
        }

        /** Sets the function that picks the query used for re-ranking; {@code null} means the default selector. */
        public ReRankingContentAggregatorBuilder querySelector(Function<Map<Query, Collection<List<Content>>>, Query> querySelector) {
            this.querySelector = querySelector;
            return this;
        }

        /** Sets the minimum re-ranked score a content must reach to be kept; {@code null} disables filtering. */
        public ReRankingContentAggregatorBuilder minScore(Double minScore) {
            this.minScore = minScore;
            return this;
        }

        /** Sets the maximum number of contents returned; {@code null} means {@link Integer#MAX_VALUE}. */
        public ReRankingContentAggregatorBuilder maxResults(Integer maxResults) {
            this.maxResults = maxResults;
            return this;
        }

        public ReRankingContentAggregator build() {
            return new ReRankingContentAggregator(this.scoringModel, this.querySelector, this.minScore, this.maxResults);
        }
    }

    /**
     * Creates an aggregator with the default query selector, no minimum score and no result limit.
     *
     * @param scoringModel the model used for re-ranking; must not be {@code null}
     */
    public ReRankingContentAggregator(ScoringModel scoringModel) {
        this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
    }

    /**
     * Creates an aggregator with no result limit.
     *
     * @param scoringModel  the model used for re-ranking; must not be {@code null}
     * @param querySelector picks the query to re-rank against; {@code null} means {@link #DEFAULT_QUERY_SELECTOR}
     * @param minScore      minimum re-ranked score to keep a content; {@code null} disables filtering
     */
    public ReRankingContentAggregator(ScoringModel scoringModel, Function<Map<Query, Collection<List<Content>>>, Query> querySelector, Double minScore) {
        this(scoringModel, querySelector, minScore, null);
    }

    /**
     * Creates an aggregator.
     *
     * @param scoringModel  the model used for re-ranking; must not be {@code null}
     * @param querySelector picks the query to re-rank against; {@code null} means {@link #DEFAULT_QUERY_SELECTOR}
     * @param minScore      minimum re-ranked score to keep a content; {@code null} disables filtering
     * @param maxResults    maximum number of contents returned; {@code null} means {@link Integer#MAX_VALUE}.
     *                      A negative value is not rejected here, but {@code Stream.limit} will throw when
     *                      {@link #aggregate} re-ranks a non-empty list.
     */
    public ReRankingContentAggregator(ScoringModel scoringModel, Function<Map<Query, Collection<List<Content>>>, Query> querySelector, Double minScore, Integer maxResults) {
        this.scoringModel = ValidationUtils.ensureNotNull(scoringModel, "scoringModel");
        this.querySelector = Utils.getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
        this.minScore = minScore;
        // Plain Integer.MAX_VALUE; the decompiled original went through Kotlin's IntCompanionObject needlessly.
        this.maxResults = Utils.getOrDefault(maxResults, Integer.MAX_VALUE);
    }

    public static ReRankingContentAggregatorBuilder builder() {
        return new ReRankingContentAggregatorBuilder();
    }

    /**
     * Fuses all contents and re-ranks them against the selected query.
     *
     * @param queryToContents for each query, the content lists retrieved for it
     * @return re-ranked contents, best first; empty if the input map is empty or fusion yields nothing
     */
    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
        if (queryToContents.isEmpty()) {
            return Collections.emptyList();
        }

        // Select the query first so an ambiguous input fails fast, before any fusion work is done.
        Query query = this.querySelector.apply(queryToContents);

        List<Content> fused = ReciprocalRankFuser.fuse(fuse(queryToContents).values());
        if (fused.isEmpty()) {
            return fused;
        }
        return reRankAndFilter(fused, query);
    }

    /**
     * Fuses, per query, all content lists retrieved for that query.
     *
     * @param queryToContents for each query, the content lists retrieved for it
     * @return one fused list per query, in the iteration order of the input map
     */
    protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        Map<Query, List<Content>> fused = new LinkedHashMap<>();
        for (Map.Entry<Query, Collection<List<Content>>> entry : queryToContents.entrySet()) {
            fused.put(entry.getKey(), ReciprocalRankFuser.fuse(entry.getValue()));
        }
        return fused;
    }

    /**
     * Scores the contents against {@code query}, drops those below {@code minScore}, and returns the rest
     * sorted by descending score, truncated to {@code maxResults}.
     * <p>
     * Contents whose text segments are equal are collapsed into one entry, and the last score computed for
     * that segment wins. Ties keep the order of the input list.
     *
     * @param contents the fused contents to re-rank
     * @param query    the query whose text the contents are scored against
     * @return new {@link Content} instances carrying {@link ContentMetadata#RERANKED_SCORE}
     * @throws IllegalStateException if the scoring model does not return exactly one score per segment
     */
    protected List<Content> reRankAndFilter(List<Content> contents, Query query) {
        List<TextSegment> segments = contents.stream()
                .map(Content::textSegment)
                .collect(Collectors.toList());

        List<Double> scores = this.scoringModel.scoreAll(segments, query.text()).content();
        if (scores == null || scores.size() != segments.size()) {
            throw new IllegalStateException("Expected " + segments.size()
                    + " scores from the scoring model (one per segment), but got "
                    + (scores == null ? "null" : String.valueOf(scores.size())));
        }

        // LinkedHashMap (not HashMap) keeps the fused order, so ties in the stable sort below are deterministic.
        Map<TextSegment, Double> segmentToScore = new LinkedHashMap<>();
        for (int i = 0; i < segments.size(); i++) {
            segmentToScore.put(segments.get(i), scores.get(i));
        }

        return segmentToScore.entrySet().stream()
                .filter(entry -> this.minScore == null || entry.getValue() >= this.minScore)
                .sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
                .map(entry -> Content.from(entry.getKey(), Map.of(ContentMetadata.RERANKED_SCORE, entry.getValue())))
                .limit(this.maxResults)
                .collect(Collectors.toList());
    }
}
