/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.deployment;

import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import io.quarkiverse.langchain4j.deployment.AiServicesProcessor;
import io.quarkiverse.langchain4j.deployment.DotNames;
import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames;
import io.quarkiverse.langchain4j.deployment.TemplateUtil;
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
import io.quarkiverse.langchain4j.runtime.StructuredPromptsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkus.builder.item.BuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.BytecodeTransformerBuildItem;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;
import io.quarkus.gizmo.ClassTransformer;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.FieldInfo;
import org.jboss.jandex.IndexView;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.analysis.Analyzer;
import org.objectweb.asm.tree.analysis.AnalyzerException;
import org.objectweb.asm.tree.analysis.BasicValue;
import org.objectweb.asm.tree.analysis.Interpreter;
import org.objectweb.asm.tree.analysis.SimpleVerifier;

public class PromptProcessor {
    // Logger category is this processor's own class, so messages emitted here are attributed to it.
    private static final Logger log = Logger.getLogger(PromptProcessor.class);
    /** Gizmo descriptor for {@code Object Map.put(Object, Object)}. */
    public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class, Object.class);
    /** Gizmo descriptor for {@code void Map.putAll(Map)}. */
    public static final MethodDescriptor MAP_PUT_ALL = MethodDescriptor.ofMethod(Map.class, "putAll", void.class, Map.class);
    // Internal (slash-separated) name of StructuredPromptProcessor, as it appears in method-call instructions.
    private static final String STRUCTURED_PROMPT_PROCESSOR_BINARY_NAME = StructuredPromptProcessor.class.getName().replace('.', '/');
    // Name and JVM descriptor of the StructuredPromptProcessor method whose call sites are inspected by analyze().
    private static final String TO_PROMPT = "toPrompt";
    private static final String TO_PROMPT_DESCRIPTOR = "(Ljava/lang/Object;)Ldev/langchain4j/model/input/Prompt;";

    /**
     * Produces a {@link RuntimeInitializedClassBuildItem} for
     * {@code dev.langchain4j.rag.content.injector.DefaultContentInjector}.
     *
     * @param producer sink that receives the build item
     */
    @BuildStep
    public void nativeSupport(BuildProducer<RuntimeInitializedClassBuildItem> producer) {
        // The item is passed with its concrete type: the producer is parameterized with
        // RuntimeInitializedClassBuildItem, so an upcast to BuildItem would not type-check.
        producer.produce(new RuntimeInitializedClassBuildItem("dev.langchain4j.rag.content.injector.DefaultContentInjector"));
    }

    /**
     * Handles every class-level {@code @StructuredPrompt} annotation found in the combined index.
     * <p>
     * For each annotated class the template is assembled by joining the annotation's value parts
     * with its {@code delimiter} (a newline when the delimiter is not set explicitly) and is
     * registered with the recorder under the class name. When the template has no nested
     * parameters, the annotated class and each indexed superclass additionally get a
     * {@link StructuredPromptAnnotatedTransformer}. Finally the index is scanned for
     * potentially unsafe {@code StructuredPromptProcessor} usages.
     *
     * @param recorder recorder that stores the template of each annotated class
     * @param combinedIndexBuildItem provides the Jandex index to scan
     * @param transformerProducer sink for the bytecode transformers
     * @throws RuntimeException if a template contains the response-schema placeholder
     */
    @BuildStep
    @Record(value=ExecutionTime.STATIC_INIT)
    public void structuredPromptSupport(StructuredPromptsRecorder recorder, CombinedIndexBuildItem combinedIndexBuildItem, BuildProducer<BytecodeTransformerBuildItem> transformerProducer) {
        IndexView index = combinedIndexBuildItem.getIndex();
        Collection<AnnotationInstance> instances = index.getAnnotations(LangChain4jDotNames.STRUCTURED_PROMPT);
        // Classes that already have a transformer registered. Annotated classes may share
        // ancestors (or be ancestors of one another); registering a second transformer for the
        // same class would add the same interface and method a second time.
        Set<DotName> transformed = new HashSet<>();
        for (AnnotationInstance instance : instances) {
            AnnotationTarget target = instance.target();
            if (target.kind() != AnnotationTarget.Kind.CLASS) {
                continue;
            }
            String[] parts = instance.value().asStringArray();
            AnnotationValue delimiterValue = instance.value("delimiter");
            String delimiter = delimiterValue != null ? delimiterValue.asString() : "\n";
            String promptTemplateString = String.join(delimiter, parts);
            ClassInfo annotatedClass = target.asClass();
            if (promptTemplateString.contains(ResponseSchemaUtil.placeholder())) {
                throw new RuntimeException("The %s placeholder is not enabled for the @StructuredPrompt. Found it: %s".formatted(ResponseSchemaUtil.placeholder(), annotatedClass));
            }
            if (!PromptProcessor.hasNestedParams(promptTemplateString)) {
                // Walk up the hierarchy until a superclass is java.lang.Object or not indexed.
                // Reaching an already-registered class ends the walk early: its own ancestors
                // were registered by the walk that first visited it.
                ClassInfo current = annotatedClass;
                while (current != null && transformed.add(current.name())) {
                    DotName superName = current.superName();
                    ClassInfo superClassInfo = DotNames.OBJECT.equals(superName) ? null : index.getClassByName(superName);
                    transformerProducer.produce(new BytecodeTransformerBuildItem(current.name().toString(), new StructuredPromptAnnotatedTransformer(current, superClassInfo != null, superName.toString())));
                    current = superClassInfo;
                }
            }
            recorder.add(annotatedClass.name().toString(), promptTemplateString);
        }
        this.warnForUnsafeUsage(index);
    }

    /**
     * Tells whether any part reported by {@code TemplateUtil.parts} for the given template
     * consists of more than one element.
     *
     * @param promptTemplateString the assembled prompt template
     * @return {@code true} if at least one part has a size greater than one
     */
    private static boolean hasNestedParams(String promptTemplateString) {
        for (var part : TemplateUtil.parts(promptTemplateString)) {
            if (part.size() > 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Logs a warning for every indexed class that is passed to
     * {@code StructuredPromptProcessor.toPrompt} but does not declare {@code @StructuredPrompt}.
     * <p>
     * The bytecode of each known user of {@code StructuredPromptProcessor} (classes in the
     * {@code io.quarkiverse.langchain4j} and {@code dev.langchain4j} packages excluded) is loaded
     * from the context class loader and every method is run through {@link #analyze}. Classes
     * whose bytecode cannot be found are skipped; classes whose analysis fails are reported at
     * debug level only.
     *
     * @param index the Jandex index to inspect
     * @throws UncheckedIOException if reading the bytecode of a class fails
     */
    private void warnForUnsafeUsage(IndexView index) {
        Set<String> candidates = new HashSet<>();
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        for (ClassInfo classInfo : index.getKnownUsers(LangChain4jDotNames.STRUCTURED_PROMPT_PROCESSOR)) {
            String className = classInfo.name().toString();
            if (className.startsWith("io.quarkiverse.langchain4j") || className.startsWith("dev.langchain4j")) {
                continue;
            }
            // try-with-resources closes the stream on every path, including analysis failures.
            try (InputStream is = classLoader.getResourceAsStream(className.replace('.', '/') + ".class")) {
                if (is == null) {
                    // No bytecode available for this class: skip it and keep checking the others.
                    continue;
                }
                // 589824 == 0x90000, the ASM API level constant for version 9 (Opcodes.ASM9).
                ClassNode cn = new ClassNode(589824);
                ClassReader cr = new ClassReader(is);
                cr.accept(cn, 0);
                for (MethodNode method : cn.methods) {
                    this.analyze(cn, method, candidates);
                }
            }
            catch (IOException e) {
                throw new UncheckedIOException("Reading bytecode of class '" + className + "' failed", e);
            }
            catch (AnalyzerException e) {
                // Best effort only: a failed analysis must not break the build.
                log.debug((Object)("Unable to analyze bytecode of class '" + className + "'"), (Throwable)e);
            }
        }
        for (String candidate : candidates) {
            ClassInfo classInfo = index.getClassByName(candidate);
            if (classInfo == null || classInfo.hasDeclaredAnnotation(LangChain4jDotNames.STRUCTURED_PROMPT)) {
                continue;
            }
            log.warn((Object)("Class '" + candidate + "' is used in StructuredPromptProcessor but it is not annotated with @StructuredPrompt. This will likely result in an exception being thrown when the prompt is used."));
        }
    }

    /**
     * Runs an ASM data-flow analysis over one method and collects the class names of the values
     * passed to {@code StructuredPromptProcessor.toPrompt(Object)}.
     * <p>
     * The interpreter wraps every value in a {@link UnionValue}, so that when control-flow paths
     * merge, all types that may reach the call site are reported rather than only their common
     * supertype.
     *
     * @param clazz the class that declares {@code method}
     * @param method the method to analyze
     * @param candidates receives the class name of every type found as the argument of a
     *        {@code toPrompt} call
     * @throws AnalyzerException if the analysis of the method fails
     */
    private void analyze(ClassNode clazz, MethodNode method, final Set<String> candidates) throws AnalyzerException {
        // 589824 == 0x90000, the ASM API level constant for version 9 (Opcodes.ASM9).
        final int asmApi = 589824;
        Type currentClass = Type.getObjectType(clazz.name);
        Type currentSuperClass = Type.getObjectType(clazz.superName);
        List<Type> currentInterfaces = clazz.interfaces.stream().map(Type::getObjectType).collect(Collectors.toList());
        // Tests the 0x0200 access bit, which is the interface flag.
        boolean isInterface = Modifier.isInterface(clazz.access);
        SimpleVerifier interpreter = new SimpleVerifier(asmApi, currentClass, currentSuperClass, currentInterfaces, isInterface) {

            @Override
            public BasicValue naryOperation(AbstractInsnNode insn, List<? extends BasicValue> values) throws AnalyzerException {
                if (insn.getType() == AbstractInsnNode.METHOD_INSN) {
                    MethodInsnNode invocation = (MethodInsnNode)insn;
                    if (STRUCTURED_PROMPT_PROCESSOR_BINARY_NAME.equals(invocation.owner) && TO_PROMPT.equals(invocation.name) && TO_PROMPT_DESCRIPTOR.equals(invocation.desc)) {
                        BasicValue argument = values.get(0);
                        if (argument instanceof UnionValue) {
                            for (Type type : ((UnionValue)argument).union) {
                                candidates.add(type.getClassName());
                            }
                        } else if (argument.getType() != null) {
                            // A value without a type carries nothing to report.
                            candidates.add(argument.getType().getClassName());
                        }
                    }
                }
                return super.naryOperation(insn, values);
            }

            @Override
            public BasicValue newValue(Type type) {
                return UnionValue.create(super.newValue(type));
            }

            @Override
            public BasicValue merge(BasicValue value1, BasicValue value2) {
                BasicValue leastUpperBound = super.merge(value1, value2);
                return UnionValue.create(leastUpperBound, value1, value2);
            }
        };
        Analyzer<BasicValue> analyzer = new Analyzer<>(interpreter);
        analyzer.analyze(clazz.name, method);
    }

    /**
     * Bytecode transformer that makes a class implement {@link Mappable} by adding an
     * {@code obtainFieldValuesMap()} method. The generated method returns a map from field name
     * to field value for the non-static, non-transient fields declared by the class. When the
     * superclass is transformed as well, the superclass method is invoked and this class's
     * entries are put into the map it returned.
     */
    private static class StructuredPromptAnnotatedTransformer
    implements BiFunction<String, ClassVisitor, ClassVisitor> {
        private final ClassInfo annotatedClass;
        private final boolean hasSuperMappable;
        private final String superClassName;

        private StructuredPromptAnnotatedTransformer(ClassInfo annotatedClass, boolean hasSuperMappable, String superClassName) {
            this.annotatedClass = annotatedClass;
            this.hasSuperMappable = hasSuperMappable;
            this.superClassName = superClassName;
        }

        /**
         * Builds the visitor that applies the transformation.
         *
         * @param s not used; the class name is taken from the {@link ClassInfo} given at construction
         * @param classVisitor the visitor to delegate to
         * @return the transforming visitor wrapping {@code classVisitor}
         */
        @Override
        public ClassVisitor apply(String s, ClassVisitor classVisitor) {
            String targetClassName = this.annotatedClass.name().toString();
            ClassTransformer classTransformer = new ClassTransformer(targetClassName);
            classTransformer.addInterface(Mappable.class);
            MethodCreator body = classTransformer.addMethod("obtainFieldValuesMap", Map.class, new Object[0]);
            MethodDescriptor hashMapConstructor = MethodDescriptor.ofConstructor(HashMap.class, new Class[0]);
            ResultHandle ownValues = body.newInstance(hashMapConstructor, new ResultHandle[0]);
            this.putDeclaredFields(body, ownValues);
            if (!this.hasSuperMappable) {
                body.returnValue(ownValues);
                return classTransformer.applyTo(classVisitor);
            }
            // Start from the superclass's map and add this class's entries on top of it.
            MethodDescriptor superMethod = MethodDescriptor.ofMethod((Object)this.superClassName, "obtainFieldValuesMap", Map.class, new Object[0]);
            ResultHandle inheritedValues = body.invokeSpecialMethod(superMethod, body.getThis(), new ResultHandle[0]);
            body.invokeInterfaceMethod(MAP_PUT_ALL, inheritedValues, new ResultHandle[]{ownValues});
            body.returnValue(inheritedValues);
            return classTransformer.applyTo(classVisitor);
        }

        // Emits one Map.put(name, value) per non-static, non-transient field declared by the class.
        private void putDeclaredFields(MethodCreator body, ResultHandle targetMap) {
            for (FieldInfo declaredField : this.annotatedClass.fields()) {
                short flags = declaredField.flags();
                boolean skipped = Modifier.isStatic(flags) || Modifier.isTransient(flags);
                if (!skipped) {
                    String key = declaredField.name();
                    ResultHandle value = body.readInstanceField(declaredField, body.getThis());
                    body.invokeInterfaceMethod(MAP_PUT, targetMap, new ResultHandle[]{body.load(key), value});
                }
            }
        }
    }

    private static class UnionValue
    extends BasicValue {
        private final Set<Type> union;

        public static BasicValue create(BasicValue value) {
            if (value == null) {
                return null;
            }
            if (value.getType() == null) {
                return new UnionValue(null, Set.of());
            }
            return new UnionValue(value.getType(), Set.of(value.getType()));
        }

        public static BasicValue create(BasicValue lub, BasicValue value1, BasicValue value2) {
            HashSet<Type> union = new HashSet<Type>();
            union.addAll(((UnionValue)value1).union);
            union.addAll(((UnionValue)value2).union);
            return new UnionValue(lub.getType(), Set.copyOf(union));
        }

        private UnionValue(Type lubType, Set<Type> union) {
            super(lubType);
            this.union = Objects.requireNonNull(union);
        }

        public String toString() {
            return super.toString() + " | union of " + String.valueOf(this.union);
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof UnionValue)) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            UnionValue that = (UnionValue)((Object)o);
            return Objects.equals(this.union, that.union);
        }

        public int hashCode() {
            return Objects.hash(super.hashCode(), this.union);
        }
    }
}

