package dev.langchain4j.service;

import dev.langchain4j.Internal;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.memory.ChatMemoryAccess;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolServiceContext;
import dev.langchain4j.service.tool.ToolServiceResult;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.services.TokenStreamAdapter;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import opennlp.tools.parser.Parse;

/* JADX INFO: Access modifiers changed from: package-private */
@Internal
/* loaded from: input_file:lib/langchain4j-1.0.0-rc1.jar:dev/langchain4j/service/DefaultAiServices.class */
public class DefaultAiServices<T> extends AiServices<T> {
    /** Parser used in this class to derive JSON schemas, output-format instructions and parsed results for a method's return type. */
    private final ServiceOutputParser serviceOutputParser;
    /** Adapters, loaded through {@link ServiceHelper#loadFactories}, consulted when a method returns a type other than {@code TokenStream}. */
    private final Collection<TokenStreamAdapter> tokenStreamAdapters;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the builder on top of the supplied context, instantiating a fresh
     * {@link ServiceOutputParser} and loading every {@link TokenStreamAdapter}
     * that {@link ServiceHelper} reports.
     *
     * @param aiServiceContext context handed to the {@link AiServices} superclass
     */
    public DefaultAiServices(AiServiceContext aiServiceContext) {
        super(aiServiceContext);
        this.serviceOutputParser = new ServiceOutputParser();
        this.tokenStreamAdapters = ServiceHelper.loadFactories(TokenStreamAdapter.class);
    }

    /**
     * Verifies the parameter annotations of an AI Service method.
     *
     * <p>Methods with fewer than two parameters are accepted as they are. For any
     * other method, each parameter must carry at least one of {@code @V},
     * {@code @UserMessage}, {@code @MemoryId} or {@code @UserName}.
     *
     * @param method the AI Service method to check
     * @throws RuntimeException the exception produced by
     *         {@code IllegalConfigurationException.illegalConfiguration} for the first
     *         parameter that has none of the annotations
     */
    static void validateParameters(Method method) {
        Parameter[] declared = method.getParameters();
        boolean needsCheck = declared != null && declared.length >= 2;
        if (!needsCheck) {
            return;
        }
        for (int index = 0; index < declared.length; index++) {
            Parameter candidate = declared[index];
            boolean annotated = candidate.isAnnotationPresent(V.class)
                    || candidate.isAnnotationPresent(UserMessage.class)
                    || candidate.isAnnotationPresent(MemoryId.class)
                    || candidate.isAnnotationPresent(UserName.class);
            if (annotated) {
                continue;
            }
            throw IllegalConfigurationException.illegalConfiguration("Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId", candidate.getName(), method.getName());
        }
    }

    /**
     * Validates the configured AI Service interface and returns a dynamic proxy that implements it.
     *
     * <p>Build-time validation rejects:
     * <ul>
     *   <li>an interface extending {@link ChatMemoryAccess} when the context has no chat memory;</li>
     *   <li>a method annotated with {@code @Moderate} when no moderation model is configured;</li>
     *   <li>a method returning {@code void};</li>
     *   <li>a raw {@code Result}, {@code List} or {@code Set} return type that fails
     *       {@code TypeUtils.validateReturnTypesAreProperlyParametrized};</li>
     *   <li>a {@code @MemoryId} parameter when the context has no chat memory.</li>
     * </ul>
     *
     * <p>The proxy's invocation handler owns a cached thread pool used only for moderation calls.
     * NOTE(review): nothing in this method shuts that executor down — confirm whether that is intended.
     *
     * @return a proxy instance of the AI Service interface
     */
    @Override // dev.langchain4j.service.AiServices
    public T build() {
        performBasicValidation();
        if (!this.context.hasChatMemory() && ChatMemoryAccess.class.isAssignableFrom(this.context.aiServiceClass)) {
            throw IllegalConfigurationException.illegalConfiguration("In order to have a service implementing ChatMemoryAccess, please configure the ChatMemoryProvider on the '%s'.", this.context.aiServiceClass.getName());
        }
        for (Method method : this.context.aiServiceClass.getMethods()) {
            if (method.isAnnotationPresent(Moderate.class) && this.context.moderationModel == null) {
                throw IllegalConfigurationException.illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
            }
            Class<?> returnType = method.getReturnType();
            if (returnType == Void.TYPE) {
                throw IllegalConfigurationException.illegalConfiguration("'%s' is not a supported return type of an AI Service method", returnType.getName());
            }
            if (returnType == Result.class || returnType == List.class || returnType == Set.class) {
                TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
            }
            if (!this.context.hasChatMemory()) {
                for (Parameter parameter : method.getParameters()) {
                    if (parameter.isAnnotationPresent(MemoryId.class)) {
                        throw IllegalConfigurationException.illegalConfiguration("In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.", this.context.aiServiceClass.getName());
                    }
                }
            }
        }
        return (T) Proxy.newProxyInstance(this.context.aiServiceClass.getClassLoader(), new Class<?>[]{this.context.aiServiceClass}, new InvocationHandler() {
            /** Runs moderation requests submitted by {@link #triggerModerationIfNeeded}. */
            private final ExecutorService executor = Executors.newCachedThreadPool();

            /**
             * Handles one call on the proxy: builds the messages, calls the chat model
             * (or sets up a token stream), runs the tool loop and parses the result.
             */
            @Override // java.lang.reflect.InvocationHandler
            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                if (method.getDeclaringClass() == Object.class) {
                    // equals(), hashCode() and toString() are answered by this handler, not by the model.
                    return method.invoke(this, args);
                }
                if (method.getDeclaringClass() == ChatMemoryAccess.class) {
                    switch (method.getName()) {
                        case "getChatMemory":
                            return DefaultAiServices.this.context.chatMemoryService.getChatMemory(args[0]);
                        case "evictChatMemory":
                            return DefaultAiServices.this.context.chatMemoryService.evictChatMemory(args[0]) != null;
                        default:
                            throw new UnsupportedOperationException("Unknown method on ChatMemoryAccess class : " + method.getName());
                    }
                }
                DefaultAiServices.validateParameters(method);
                Object memoryId = DefaultAiServices.findMemoryId(method, args).orElse("default");
                ChatMemory chatMemory = DefaultAiServices.this.context.hasChatMemory()
                        ? DefaultAiServices.this.context.chatMemoryService.getOrCreateChatMemory(memoryId)
                        : null;
                Optional<dev.langchain4j.data.message.SystemMessage> systemMessage = DefaultAiServices.this.prepareSystemMessage(memoryId, method, args);
                dev.langchain4j.data.message.UserMessage userMessage = DefaultAiServices.prepareUserMessage(method, args);

                // Retrieval augmentation may replace the user message with an augmented one.
                AugmentationResult augmentationResult = null;
                if (DefaultAiServices.this.context.retrievalAugmentor != null) {
                    augmentationResult = DefaultAiServices.this.context.retrievalAugmentor.augment(new AugmentationRequest(userMessage, Metadata.from(userMessage, memoryId, chatMemory != null ? chatMemory.messages() : null)));
                    userMessage = (dev.langchain4j.data.message.UserMessage) augmentationResult.chatMessage();
                }

                Type returnType = method.getGenericReturnType();
                boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);
                boolean supportsJsonSchema = supportsJsonSchema();
                Optional<JsonSchema> jsonSchema = Optional.empty();
                if (supportsJsonSchema && !streaming) {
                    jsonSchema = DefaultAiServices.this.serviceOutputParser.jsonSchema(returnType);
                }
                // Without a usable JSON schema, the expected format is described in the prompt instead.
                if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) {
                    userMessage = appendOutputFormatInstructions(returnType, userMessage);
                }

                List<ChatMessage> messages;
                if (chatMemory != null) {
                    systemMessage.ifPresent(chatMemory::add);
                    chatMemory.add(userMessage);
                    messages = chatMemory.messages();
                } else {
                    messages = new ArrayList<>();
                    systemMessage.ifPresent(messages::add);
                    messages.add(userMessage);
                }

                Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);
                ToolServiceContext toolServiceContext = DefaultAiServices.this.context.toolService.createContext(memoryId, userMessage);

                if (streaming) {
                    AiServiceTokenStream tokenStream = new AiServiceTokenStream(AiServiceTokenStreamParameters.builder()
                            .messages(messages)
                            .toolSpecifications(toolServiceContext.toolSpecifications())
                            .toolExecutors(toolServiceContext.toolExecutors())
                            .retrievedContents(augmentationResult != null ? augmentationResult.contents() : null)
                            .context(DefaultAiServices.this.context)
                            .memoryId(memoryId)
                            .build());
                    return returnType == TokenStream.class ? tokenStream : adapt(tokenStream, returnType);
                }

                ResponseFormat responseFormat = null;
                if (supportsJsonSchema && jsonSchema.isPresent()) {
                    responseFormat = ResponseFormat.builder().type(ResponseFormatType.JSON).jsonSchema(jsonSchema.get()).build();
                }
                ChatRequestParameters parameters = ChatRequestParameters.builder().toolSpecifications(toolServiceContext.toolSpecifications()).responseFormat(responseFormat).build();
                ChatResponse initialResponse = DefaultAiServices.this.context.chatModel.chat(ChatRequest.builder().messages(messages).parameters(parameters).build());
                AiServices.verifyModerationIfNeeded(moderationFuture);
                ToolServiceResult toolServiceResult = DefaultAiServices.this.context.toolService.executeInferenceAndToolsLoop(initialResponse, parameters, messages, DefaultAiServices.this.context.chatModel, chatMemory, memoryId, toolServiceContext.toolExecutors());
                ChatResponse finalResponse = toolServiceResult.chatResponse();
                Object parsedResponse = DefaultAiServices.this.serviceOutputParser.parse(finalResponse, returnType);
                if (TypeUtils.typeHasRawClass(returnType, Result.class)) {
                    return Result.builder()
                            .content(parsedResponse)
                            .tokenUsage(finalResponse.tokenUsage())
                            .sources(augmentationResult == null ? null : augmentationResult.contents())
                            .finishReason(finalResponse.finishReason())
                            .toolExecutions(toolServiceResult.toolExecutions())
                            .build();
                }
                return parsedResponse;
            }

            /** Returns true when any loaded adapter can turn a TokenStream into the given type. */
            private boolean canAdaptTokenStreamTo(Type type) {
                for (TokenStreamAdapter tokenStreamAdapter : DefaultAiServices.this.tokenStreamAdapters) {
                    if (tokenStreamAdapter.canAdaptTokenStreamTo(type)) {
                        return true;
                    }
                }
                return false;
            }

            /** Adapts the token stream with the first adapter that accepts the given type. */
            private Object adapt(TokenStream tokenStream, Type type) {
                for (TokenStreamAdapter tokenStreamAdapter : DefaultAiServices.this.tokenStreamAdapters) {
                    if (tokenStreamAdapter.canAdaptTokenStreamTo(type)) {
                        return tokenStreamAdapter.adapt(tokenStream);
                    }
                }
                throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
            }

            /** Returns true when a chat model is configured and reports the RESPONSE_FORMAT_JSON_SCHEMA capability. */
            private boolean supportsJsonSchema() {
                return DefaultAiServices.this.context.chatModel != null && DefaultAiServices.this.context.chatModel.supportedCapabilities().contains(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
            }

            /** Returns a copy of the user message with the parser's format instructions appended, keeping the user name if set. */
            private dev.langchain4j.data.message.UserMessage appendOutputFormatInstructions(Type type, dev.langchain4j.data.message.UserMessage userMessage) {
                String text = userMessage.singleText() + DefaultAiServices.this.serviceOutputParser.outputFormatInstructions(type);
                return Utils.isNotNullOrBlank(userMessage.name()) ? dev.langchain4j.data.message.UserMessage.from(userMessage.name(), text) : dev.langchain4j.data.message.UserMessage.from(text);
            }

            /** Submits a moderation call on the executor for {@code @Moderate} methods; returns null otherwise. */
            private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
                if (method.isAnnotationPresent(Moderate.class)) {
                    return this.executor.submit(() -> {
                        return DefaultAiServices.this.context.moderationModel.moderate(AiServices.removeToolMessages(messages)).content();
                    });
                }
                return null;
            }
        });
    }

    /**
     * Builds the system message for an invocation, if a system message template exists.
     *
     * @param obj    the memory id passed on to {@link #findSystemMessageTemplate}
     * @param method the invoked AI Service method
     * @param objArr the invocation arguments, used as template variables
     * @return the rendered system message, or empty when no template was found
     */
    private Optional<dev.langchain4j.data.message.SystemMessage> prepareSystemMessage(Object obj, Method method, Object[] objArr) {
        Optional<String> template = findSystemMessageTemplate(obj, method);
        if (!template.isPresent()) {
            return Optional.empty();
        }
        String text = template.get();
        PromptTemplate promptTemplate = PromptTemplate.from(text);
        Map<String, Object> variables = findTemplateVariables(text, method, objArr);
        return Optional.ofNullable(promptTemplate.apply(variables).toSystemMessage());
    }

    /**
     * Looks up the system message template for a method.
     *
     * <p>A {@code @SystemMessage} annotation on the method takes precedence; without
     * one, the context's {@code systemMessageProvider} is asked with the given memory id.
     *
     * @param obj    the memory id handed to the system message provider
     * @param method the AI Service method
     * @return the template text, or whatever the provider returns
     */
    private Optional<String> findSystemMessageTemplate(Object obj, Method method) {
        SystemMessage annotation = method.getAnnotation(SystemMessage.class);
        if (annotation == null) {
            return this.context.systemMessageProvider.apply(obj);
        }
        String template = getTemplate(method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter());
        return Optional.of(template);
    }

    /**
     * Collects the template variables for a method invocation.
     *
     * <p>Each argument is stored under its parameter's variable name. When the
     * template contains <code>{{it}}</code> and no variable is called "it", the value
     * is resolved through {@link #getValueOfVariableIt}.
     *
     * @param str    the prompt template text
     * @param method the invoked method
     * @param objArr the invocation arguments, positionally aligned with the method's parameters
     * @return a mutable map from variable name to value
     */
    private static Map<String, Object> findTemplateVariables(String str, Method method, Object[] objArr) {
        Map<String, Object> variables = new HashMap<>();
        Parameter[] declared = method.getParameters();
        int position = 0;
        for (Parameter parameter : declared) {
            variables.put(getVariableName(parameter), objArr[position]);
            position++;
        }
        boolean templateUsesIt = str.contains("{{it}}");
        if (templateUsesIt && !variables.containsKey("it")) {
            variables.put("it", getValueOfVariableIt(declared, objArr));
        }
        return variables;
    }

    /**
     * Returns the template variable name for a parameter: the {@code @V} value when
     * the annotation is present, otherwise the reflective parameter name.
     */
    private static String getVariableName(Parameter parameter) {
        V annotation = parameter.getAnnotation(V.class);
        if (annotation == null) {
            return parameter.getName();
        }
        return annotation.value();
    }

    /**
     * Resolves the value of the <code>{{it}}</code> template variable.
     *
     * <p>A single parameter supplies the value unless it is annotated with
     * {@code @MemoryId}, {@code @UserMessage} or {@code @UserName}, or with a
     * {@code @V} whose name is not "it". Otherwise the first parameter annotated
     * with {@code @V("it")} is used.
     *
     * @param parameterArr the method parameters
     * @param objArr       the invocation arguments, positionally aligned with the parameters
     * @return the string form of the chosen argument
     * @throws RuntimeException the exception produced by
     *         {@code IllegalConfigurationException.illegalConfiguration} when no parameter qualifies
     */
    private static String getValueOfVariableIt(Parameter[] parameterArr, Object[] objArr) {
        if (parameterArr.length == 1) {
            Parameter only = parameterArr[0];
            boolean reserved = only.isAnnotationPresent(MemoryId.class)
                    || only.isAnnotationPresent(UserMessage.class)
                    || only.isAnnotationPresent(UserName.class);
            boolean boundToOtherName = only.isAnnotationPresent(V.class) && !isAnnotatedWithIt(only);
            if (!reserved && !boundToOtherName) {
                return toString(objArr[0]);
            }
        }
        int position = 0;
        for (Parameter candidate : parameterArr) {
            if (isAnnotatedWithIt(candidate)) {
                return toString(objArr[position]);
            }
            position++;
        }
        throw IllegalConfigurationException.illegalConfiguration("Error: cannot find the value of the prompt template variable \"{{it}}\".");
    }

    /** Tells whether the parameter carries a {@code @V} annotation whose value is exactly "it". */
    private static boolean isAnnotatedWithIt(Parameter parameter) {
        V annotation = parameter.getAnnotation(V.class);
        if (annotation == null) {
            return false;
        }
        return "it".equals(annotation.value());
    }

    /**
     * Builds the user message for an invocation.
     *
     * <p>The user message template is rendered with the invocation's template
     * variables. When a parameter annotated with {@code @UserName} is present, its
     * value becomes the message's name; otherwise the rendered prompt is converted
     * to a user message directly.
     *
     * @param method the invoked AI Service method
     * @param objArr the invocation arguments
     * @return the user message to send
     */
    private static dev.langchain4j.data.message.UserMessage prepareUserMessage(Method method, Object[] objArr) {
        String template = getUserMessageTemplate(method, objArr);
        PromptTemplate promptTemplate = PromptTemplate.from(template);
        Prompt prompt = promptTemplate.apply(findTemplateVariables(template, method, objArr));
        Optional<String> userName = findUserName(method.getParameters(), objArr);
        return userName
                .map(name -> dev.langchain4j.data.message.UserMessage.from(name, prompt.text()))
                .orElseGet(prompt::toUserMessage);
    }

    /**
     * Determines the user message template for an invocation.
     *
     * <p>Sources are tried in this order: a {@code @UserMessage} annotation on the
     * method, a parameter annotated with {@code @UserMessage}, then the sole
     * un-annotated argument. Having both of the first two is rejected.
     *
     * @param method the invoked method
     * @param objArr the invocation arguments
     * @return the template text
     * @throws RuntimeException the exception produced by
     *         {@code IllegalConfigurationException.illegalConfiguration} when the sources
     *         conflict or none is available
     */
    private static String getUserMessageTemplate(Method method, Object[] objArr) {
        Optional<String> fromMethod = findUserMessageTemplateFromMethodAnnotation(method);
        Optional<String> fromParameter = findUserMessageTemplateFromAnnotatedParameter(method.getParameters(), objArr);
        if (fromMethod.isPresent()) {
            if (fromParameter.isPresent()) {
                throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' has multiple @UserMessage annotations. Please use only one.", method.getName());
            }
            return fromMethod.get();
        }
        if (fromParameter.isPresent()) {
            return fromParameter.get();
        }
        Optional<String> fromOnlyArgument = findUserMessageTemplateFromTheOnlyArgument(method.getParameters(), objArr);
        if (!fromOnlyArgument.isPresent()) {
            throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' does not have a user message defined.", method.getName());
        }
        return fromOnlyArgument.get();
    }

    /**
     * Reads the user message template from a {@code @UserMessage} annotation on the method.
     *
     * @param method the AI Service method
     * @return the template resolved via {@link #getTemplate}, or empty when the method is not annotated
     */
    private static Optional<String> findUserMessageTemplateFromMethodAnnotation(Method method) {
        UserMessage annotation = method.getAnnotation(UserMessage.class);
        if (annotation == null) {
            return Optional.empty();
        }
        String template = getTemplate(method, "User", annotation.fromResource(), annotation.value(), annotation.delimiter());
        return Optional.ofNullable(template);
    }

    /**
     * Returns the string form of the argument bound to the first parameter annotated
     * with {@code @UserMessage}, or empty when there is no such parameter.
     *
     * @param parameterArr the method parameters
     * @param objArr       the invocation arguments, positionally aligned with the parameters
     */
    private static Optional<String> findUserMessageTemplateFromAnnotatedParameter(Parameter[] parameterArr, Object[] objArr) {
        int position = 0;
        for (Parameter candidate : parameterArr) {
            if (candidate.isAnnotationPresent(UserMessage.class)) {
                String template = toString(objArr[position]);
                return Optional.of(template);
            }
            position++;
        }
        return Optional.empty();
    }

    /**
     * Uses the sole argument as the user message template, provided the method has
     * exactly one parameter and that parameter has no annotations at all.
     *
     * @param parameterArr the method parameters (may be null)
     * @param objArr       the invocation arguments
     * @return the string form of the only argument, or empty when the conditions are not met
     */
    private static Optional<String> findUserMessageTemplateFromTheOnlyArgument(Parameter[] parameterArr, Object[] objArr) {
        if (parameterArr == null || parameterArr.length != 1) {
            return Optional.empty();
        }
        if (parameterArr[0].getAnnotations().length != 0) {
            return Optional.empty();
        }
        return Optional.of(toString(objArr[0]));
    }

    /**
     * Returns {@code toString()} of the argument bound to the first parameter
     * annotated with {@code @UserName}, or empty when no parameter is annotated.
     *
     * @param parameterArr the method parameters
     * @param objArr       the invocation arguments, positionally aligned with the parameters
     */
    private static Optional<String> findUserName(Parameter[] parameterArr, Object[] objArr) {
        int position = 0;
        for (Parameter candidate : parameterArr) {
            if (candidate.isAnnotationPresent(UserName.class)) {
                String userName = objArr[position].toString();
                return Optional.of(userName);
            }
            position++;
        }
        return Optional.empty();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Resolves the text of a message template.
     *
     * <p>When {@code str2} (the resource name) is blank, the template is the
     * {@code strArr} lines joined with {@code str3}. Otherwise it is read from the
     * named resource, looked up relative to the method's declaring class.
     *
     * @param method the annotated method; its declaring class is used to load resources
     * @param str    the message kind used in error messages (callers here pass "System" or "User")
     * @param str2   the resource name; blank means "use the inline value"
     * @param strArr the inline template lines
     * @param str3   the delimiter used to join the inline lines
     * @return the non-blank template text
     * @throws RuntimeException the exception produced by
     *         {@code IllegalConfigurationException.illegalConfiguration} when the resource
     *         is missing or the resulting template is blank
     */
    public static String getTemplate(Method method, String str, String str2, String[] strArr, String str3) {
        boolean inline = str2.trim().isEmpty();
        String template = inline
                ? String.join(str3, strArr)
                : getResourceText(method.getDeclaringClass(), str2);
        if (!inline && template == null) {
            throw IllegalConfigurationException.illegalConfiguration("@%sMessage's resource '%s' not found", str, str2);
        }
        if (template.trim().isEmpty()) {
            throw IllegalConfigurationException.illegalConfiguration("@%sMessage's template cannot be empty", str);
        }
        return template;
    }

    /**
     * Reads a classpath resource as text.
     *
     * <p>The name is first tried as given; if that yields nothing it is retried with
     * a leading "/".
     *
     * @param cls the class whose {@code getResourceAsStream} is used for the lookup
     * @param str the resource name
     * @return the result of {@link #getText} for whichever stream was found (or for null when neither lookup succeeded)
     */
    private static String getResourceText(Class<?> cls, String str) {
        InputStream stream = cls.getResourceAsStream(str);
        return getText(stream != null ? stream : cls.getResourceAsStream("/" + str));
    }

    /**
     * Reads the whole stream into a string and closes it.
     *
     * <p>The bytes are decoded as UTF-8 so the result does not depend on the
     * platform default charset.
     *
     * @param inputStream the stream to read; may be null
     * @return the stream's full content, an empty string for an empty stream, or
     *         null when {@code inputStream} is null
     */
    private static String getText(InputStream inputStream) {
        if (inputStream == null) {
            return null;
        }
        // Closing the scanner also closes the underlying stream.
        try (Scanner scanner = new Scanner(inputStream, "UTF-8")) {
            // "\\A" matches only at the beginning of input, so the first token is the entire content.
            scanner.useDelimiter("\\A");
            return scanner.hasNext() ? scanner.next() : "";
        }
    }

    /**
     * Finds the argument bound to the first parameter annotated with {@code @MemoryId}.
     *
     * @param method the invoked method
     * @param objArr the invocation arguments, positionally aligned with the method's parameters
     * @return the memory id, or empty when no parameter is annotated
     * @throws RuntimeException the exception produced by {@code Exceptions.illegalArgument}
     *         when the annotated argument is null
     */
    private static Optional<Object> findMemoryId(Method method, Object[] objArr) {
        int position = -1;
        for (Parameter candidate : method.getParameters()) {
            position++;
            if (!candidate.isAnnotationPresent(MemoryId.class)) {
                continue;
            }
            Object memoryId = objArr[position];
            if (memoryId != null) {
                return Optional.of(memoryId);
            }
            throw Exceptions.illegalArgument("The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null", candidate.getName(), method.getName());
        }
        return Optional.empty();
    }

    /**
     * Converts an argument to the text used in prompts.
     *
     * <p>Arrays are rendered via {@link #arrayToString}, objects whose class is
     * annotated with {@code @StructuredPrompt} via {@code StructuredPromptProcessor},
     * and everything else via {@code toString()}.
     *
     * @param obj the argument to convert; dereferenced without a null check
     */
    private static String toString(Object obj) {
        Class<?> type = obj.getClass();
        if (type.isArray()) {
            return arrayToString(obj);
        }
        if (type.isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt(obj).text();
        }
        return obj.toString();
    }

    /**
     * Renders an array as {@code [e1, e2, ...]}, converting each element with
     * {@link #toString(Object)}.
     *
     * <p>Plain bracket literals are used rather than constants from an unrelated
     * NLP parser class, so this method depends only on the JDK.
     *
     * @param obj an array of any component type, accessed reflectively
     * @return the bracketed, comma-separated text; {@code []} for an empty array
     */
    private static String arrayToString(Object obj) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(obj);
        for (int i = 0; i < length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(toString(Array.get(obj, i)));
        }
        sb.append("]");
        return sb.toString();
    }
}
