package dev.langchain4j.mcp;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.mcp.client.McpClient;
import dev.langchain4j.service.IllegalConfigurationException;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiPredicate;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:lib/langchain4j-mcp-1.1.0-beta7.jar:dev/langchain4j/mcp/McpToolProvider.class */
/**
 * A {@link ToolProvider} that exposes the tools offered by one or more {@link McpClient}s.
 *
 * <p>Each call to {@link #provideTools(ToolProviderRequest)} asks every registered client for its
 * tools via {@link McpClient#listTools()}. Tools accepted by the current filter are added to the
 * result, each paired with an executor that forwards execution to the client that listed it,
 * decorated by the configured tool wrapper.
 *
 * <p>Clients are held in a {@link CopyOnWriteArrayList} and the filter in an
 * {@link AtomicReference}, so clients and filters can be changed through the mutator methods on
 * this instance. A single {@code provideTools} call uses the filter read at its start.
 */
public class McpToolProvider implements ToolProvider {
    private static final Logger log = LoggerFactory.getLogger(McpToolProvider.class);

    private final CopyOnWriteArrayList<McpClient> mcpClients;
    private final boolean failIfOneServerFails;
    private final AtomicReference<BiPredicate<McpClient, ToolSpecification>> mcpToolsFilter;
    private final Function<ToolExecutor, ToolExecutor> toolWrapper;

    /** Builder for {@link McpToolProvider}. */
    public static class Builder {
        private List<McpClient> mcpClients;
        // Nullable on purpose: "not set" resolves to false when the provider is built.
        private Boolean failIfOneServerFails;
        private BiPredicate<McpClient, ToolSpecification> mcpToolsFilter = (mcpClient, toolSpecification) -> true;
        private Function<ToolExecutor, ToolExecutor> toolWrapper = Function.identity();

        /**
         * Sets the MCP clients whose tools will be provided. The list is copied when the provider is built.
         *
         * @param mcpClients the clients; must not be {@code null} when {@link #build()} is called
         * @return this builder
         */
        public Builder mcpClients(List<McpClient> mcpClients) {
            this.mcpClients = mcpClients;
            return this;
        }

        /**
         * Varargs convenience for {@link #mcpClients(List)}.
         *
         * @param mcpClients the clients
         * @return this builder
         */
        public Builder mcpClients(McpClient... mcpClients) {
            return mcpClients(Arrays.asList(mcpClients));
        }

        /**
         * Adds a tool filter. Filters accumulate: a tool is provided only if every added filter accepts it.
         *
         * @param filter predicate over the originating client and the tool specification
         * @return this builder
         */
        public Builder filter(BiPredicate<McpClient, ToolSpecification> filter) {
            this.mcpToolsFilter = this.mcpToolsFilter.and(filter);
            return this;
        }

        /**
         * Adds a filter that accepts only tools whose name equals one of the given names.
         *
         * @param toolNames the tool names to keep
         * @return this builder
         */
        public Builder filterToolNames(String... toolNames) {
            return filter(new ToolsNameFilter(toolNames));
        }

        /**
         * Controls what happens when a client fails to list its tools. If {@code true},
         * {@code provideTools} throws. If {@code false}, the failure is logged and that client is skipped.
         * Defaults to {@code false} when never set.
         *
         * @param failIfOneServerFails whether a single failing client aborts tool provisioning
         * @return this builder
         */
        public Builder failIfOneServerFails(boolean failIfOneServerFails) {
            this.failIfOneServerFails = failIfOneServerFails;
            return this;
        }

        /**
         * Sets a decorator applied to every per-client tool executor. Defaults to the identity function.
         *
         * @param toolWrapper the decorator; must not be {@code null}
         * @return this builder
         */
        public Builder toolWrapper(Function<ToolExecutor, ToolExecutor> toolWrapper) {
            this.toolWrapper = toolWrapper;
            return this;
        }

        /**
         * Builds the provider.
         *
         * @return a new {@link McpToolProvider}
         * @throws NullPointerException if the client list, filter or tool wrapper is {@code null}
         */
        public McpToolProvider build() {
            return new McpToolProvider(this);
        }
    }

    /** Executes a tool by delegating to {@link McpClient#executeTool}; the memory id argument is ignored. */
    public static class DefaultToolExecutor implements ToolExecutor {
        private final McpClient mcpClient;

        public DefaultToolExecutor(McpClient mcpClient) {
            this.mcpClient = mcpClient;
        }

        @Override
        public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
            return this.mcpClient.executeTool(toolExecutionRequest);
        }
    }

    /** Accepts a tool only if its name is one of a fixed list of names. */
    private static class ToolsNameFilter implements BiPredicate<McpClient, ToolSpecification> {
        private final List<String> toolNames;

        private ToolsNameFilter(String... toolNames) {
            this(Arrays.asList(toolNames));
        }

        private ToolsNameFilter(List<String> toolNames) {
            this.toolNames = toolNames;
        }

        @Override
        public boolean test(McpClient mcpClient, ToolSpecification toolSpecification) {
            // List.contains is null-safe, unlike calling equals on each (possibly null) configured name.
            return this.toolNames.contains(toolSpecification.name());
        }
    }

    private McpToolProvider(Builder builder) {
        // The builder field is a nullable Boolean. Resolve the default on the boxed value.
        // Casting it straight to boolean would throw NullPointerException when the option was never set.
        this(
                builder.mcpClients,
                Utils.getOrDefault(builder.failIfOneServerFails, Boolean.FALSE),
                builder.mcpToolsFilter,
                builder.toolWrapper);
    }

    /**
     * Creates a provider with no tool wrapping.
     *
     * @param mcpClients the clients to query; copied, must not be {@code null}
     * @param failIfOneServerFails whether a client that fails to list tools aborts {@code provideTools}
     * @param mcpToolsFilter the initial tool filter; must not be {@code null}
     */
    protected McpToolProvider(
            List<McpClient> mcpClients,
            boolean failIfOneServerFails,
            BiPredicate<McpClient, ToolSpecification> mcpToolsFilter) {
        this(mcpClients, failIfOneServerFails, mcpToolsFilter, Function.identity());
    }

    /**
     * Creates a provider.
     *
     * @param mcpClients the clients to query; copied, must not be {@code null}
     * @param failIfOneServerFails whether a client that fails to list tools aborts {@code provideTools}
     * @param mcpToolsFilter the initial tool filter; must not be {@code null}
     * @param toolWrapper decorator applied to each per-client executor; must not be {@code null}
     * @throws NullPointerException if any reference argument is {@code null}
     */
    protected McpToolProvider(
            List<McpClient> mcpClients,
            boolean failIfOneServerFails,
            BiPredicate<McpClient, ToolSpecification> mcpToolsFilter,
            Function<ToolExecutor, ToolExecutor> toolWrapper) {
        this.mcpClients = new CopyOnWriteArrayList<>(Objects.requireNonNull(mcpClients, "mcpClients"));
        this.failIfOneServerFails = failIfOneServerFails;
        this.mcpToolsFilter = new AtomicReference<>(Objects.requireNonNull(mcpToolsFilter, "mcpToolsFilter"));
        this.toolWrapper = Objects.requireNonNull(toolWrapper, "toolWrapper");
    }

    /**
     * Registers an additional client whose tools will be included in later calls to {@code provideTools}.
     *
     * @param mcpClient the client to add; must not be {@code null}
     */
    public void addMcpClient(McpClient mcpClient) {
        this.mcpClients.add(Objects.requireNonNull(mcpClient, "mcpClient"));
    }

    /**
     * Unregisters a client. Has no effect if the client was not registered.
     *
     * @param mcpClient the client to remove
     */
    public void removeMcpClient(McpClient mcpClient) {
        this.mcpClients.remove(mcpClient);
    }

    /**
     * Combines the current filter with {@code filter} using logical AND.
     *
     * @param filter the additional filter; must not be {@code null}
     */
    public void addFilter(BiPredicate<McpClient, ToolSpecification> filter) {
        Objects.requireNonNull(filter, "filter");
        // updateAndGet retries under contention; BiPredicate.and is side-effect free, so retries are harmless.
        this.mcpToolsFilter.updateAndGet(current -> current.and(filter));
    }

    /**
     * Replaces the current filter (including any previously added ones) with {@code filter}.
     *
     * @param filter the new filter; must not be {@code null}
     */
    public void setFilter(BiPredicate<McpClient, ToolSpecification> filter) {
        this.mcpToolsFilter.set(Objects.requireNonNull(filter, "filter"));
    }

    /** Replaces the current filter with one that accepts every tool. */
    public void resetFilters() {
        setFilter((mcpClient, toolSpecification) -> true);
    }

    @Override
    public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
        return provideTools(toolProviderRequest, this.mcpToolsFilter.get());
    }

    /**
     * Lists the tools of every registered client and keeps those accepted by {@code filter}.
     *
     * @param toolProviderRequest the request; not inspected by this implementation
     * @param filter predicate over the originating client and the tool specification
     * @return the accepted tools, each paired with a wrapped executor bound to its client
     * @throws IllegalConfigurationException always propagated as-is
     * @throws RuntimeException if a client fails and {@code failIfOneServerFails} is {@code true}
     */
    protected ToolProviderResult provideTools(
            ToolProviderRequest toolProviderRequest, BiPredicate<McpClient, ToolSpecification> filter) {
        ToolProviderResult.Builder result = ToolProviderResult.builder();
        for (McpClient mcpClient : this.mcpClients) {
            DefaultToolExecutor executor = new DefaultToolExecutor(mcpClient);
            try {
                mcpClient.listTools().stream()
                        .filter(tool -> filter.test(mcpClient, tool))
                        .forEach(tool -> result.add(tool, this.toolWrapper.apply(executor)));
            } catch (IllegalConfigurationException e) {
                // Configuration errors are never treated as a per-server failure.
                throw e;
            } catch (Exception e) {
                if (this.failIfOneServerFails) {
                    throw new RuntimeException("Failed to retrieve tools from MCP server", e);
                }
                log.warn("Failed to retrieve tools from MCP server", e);
            }
        }
        return result.build();
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
