/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.listeners;

import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.TokenUsage;
import io.micrometer.core.instrument.Clock;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Meter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.quarkiverse.langchain4j.cost.Cost;
import io.quarkiverse.langchain4j.cost.CostEstimatorService;
import io.quarkiverse.langchain4j.runtime.ContextLocals;
import java.util.concurrent.TimeUnit;
import org.jboss.logging.Logger;

public class MetricsChatModelListener
implements ChatModelListener {
    private static final Logger log = Logger.getLogger(MetricsChatModelListener.class);
    public static final String START_TIME_KEY_NAME = "startTime";
    private final CostEstimatorService costEstimatorService;
    private final Meter.MeterProvider<Counter> inputTokenUsage;
    private final Meter.MeterProvider<Counter> outputTokenUsage;
    private final Meter.MeterProvider<Timer> duration;
    private final Meter.MeterProvider<Counter> estimatedCost;

    public MetricsChatModelListener(CostEstimatorService costEstimatorService) {
        this.costEstimatorService = costEstimatorService;
        this.inputTokenUsage = Counter.builder((String)"gen_ai.client.token.usage").description("Measures number of input tokens used").tag("gen_ai.operation.name", "completion").tag("gen_ai.token.type", "input").withRegistry((MeterRegistry)Metrics.globalRegistry);
        this.outputTokenUsage = Counter.builder((String)"gen_ai.client.token.usage").description("Measures number of output tokens used").tag("gen_ai.operation.name", "completion").tag("gen_ai.token.type", "output").withRegistry((MeterRegistry)Metrics.globalRegistry);
        this.duration = Timer.builder((String)"gen_ai.client.operation.duration").description("GenAI operation duration").tag("gen_ai.operation.name", "completion").withRegistry((MeterRegistry)Metrics.globalRegistry);
        this.estimatedCost = Counter.builder((String)"gen_ai.client.estimated_cost").description("Estimated cost of the request").tag("gen_ai.operation.name", "completion").tag("gen_ai.token.type", "output").withRegistry((MeterRegistry)Metrics.globalRegistry);
    }

    public void onRequest(ChatModelRequestContext requestContext) {
        long startTime = Clock.SYSTEM.monotonicTime();
        requestContext.attributes().put(START_TIME_KEY_NAME, startTime);
    }

    public void onResponse(ChatModelResponseContext responseContext) {
        long endTime = Clock.SYSTEM.monotonicTime();
        ChatRequest request = responseContext.chatRequest();
        ChatResponse response = responseContext.chatResponse();
        Tags tags = Tags.empty();
        if (request.parameters().modelName() != null) {
            tags = tags.and("gen_ai.request.model", request.parameters().modelName());
        }
        if (response.metadata().modelName() != null) {
            tags = tags.and("gen_ai.response.model", response.metadata().modelName());
        }
        if (ContextLocals.duplicatedContextActive()) {
            String aiServiceMethodName;
            String aiServiceClassName = (String)ContextLocals.get("aiservice.classname");
            if (aiServiceClassName != null) {
                tags = tags.and("ai_service.class_name", aiServiceClassName);
            }
            if ((aiServiceMethodName = (String)ContextLocals.get("aiservice.methodname")) != null) {
                tags = tags.and("ai_service.method_name", aiServiceMethodName);
            }
        }
        this.recordTokenUsage(responseContext, tags);
        this.recordDuration(responseContext, endTime, tags);
    }

    public void onError(ChatModelErrorContext errorContext) {
        long endTime = Clock.SYSTEM.monotonicTime();
        Long startTime = (Long)errorContext.attributes().get(START_TIME_KEY_NAME);
        if (startTime == null) {
            log.warn((Object)"No start time found in response");
            return;
        }
        Tags tags = Tags.of((String)"gen_ai.request.model", (String)errorContext.chatRequest().parameters().modelName());
        if (errorContext.error() != null) {
            tags = tags.and("error.type", errorContext.error().getMessage());
        }
        ((Timer)this.duration.withTags((Iterable)tags)).record(endTime - startTime, TimeUnit.NANOSECONDS);
    }

    private void recordTokenUsage(ChatModelResponseContext responseContext, Tags tags) {
        Cost costEstimate;
        Integer outputTokenCount;
        TokenUsage tokenUsage = responseContext.chatResponse().tokenUsage();
        if (tokenUsage == null) {
            return;
        }
        Integer inputTokenCount = tokenUsage.inputTokenCount();
        if (inputTokenCount != null) {
            ((Counter)this.inputTokenUsage.withTags((Iterable)tags)).increment((double)inputTokenCount.intValue());
        }
        if ((outputTokenCount = tokenUsage.outputTokenCount()) != null) {
            ((Counter)this.outputTokenUsage.withTags((Iterable)tags)).increment((double)outputTokenCount.intValue());
        }
        if (inputTokenCount != null && outputTokenCount != null && (costEstimate = this.costEstimatorService.estimate(responseContext)) != null) {
            ((Counter)this.estimatedCost.withTags((Iterable)tags.and("currency", costEstimate.currencyCode()))).increment(costEstimate.number().doubleValue());
        }
    }

    private void recordDuration(ChatModelResponseContext responseContext, long endTime, Tags tags) {
        Long startTime = (Long)responseContext.attributes().get(START_TIME_KEY_NAME);
        if (startTime == null) {
            log.warn((Object)"No start time found in response");
            return;
        }
        ((Timer)this.duration.withTags((Iterable)tags)).record(endTime - startTime, TimeUnit.NANOSECONDS);
    }
}

