/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.rag.content.util;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helpers for attaching query and document embeddings to a {@link TextSegment}'s metadata and
 * reading them back.
 *
 * <p>Each embedding vector is stored in metadata as a Base64 string: the vector's floats are
 * written into a {@link ByteBuffer} (default big-endian byte order) and the resulting bytes are
 * Base64-encoded. Decoding reverses the same layout.
 */
public final class EmbeddingMetadataUtils {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingMetadataUtils.class);

    /** Metadata key under which the document embedding is stored. */
    private static final String DOCUMENT_EMBEDDING_KEY = "embedding";

    /** Metadata key under which the query embedding is stored. */
    private static final String QUERY_EMBEDDING_KEY = "queryEmbedding";

    private EmbeddingMetadataUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Returns a new segment with the same text as {@code segment} whose metadata is a copy of the
     * original metadata plus the Base64-encoded embeddings.
     *
     * <p>The input segment is not modified. Existing entries under the embedding keys are
     * overwritten. A {@code null} embedding is simply not added.
     *
     * @param segment           the segment to enrich; must not be {@code null}
     * @param queryEmbedding    embedding stored under {@code "queryEmbedding"}, or {@code null}
     * @param documentEmbedding embedding stored under {@code "embedding"}, or {@code null}
     * @return a new segment carrying the merged metadata
     */
    public static TextSegment enrichSegmentWithEmbeddings(
            TextSegment segment, Embedding queryEmbedding, Embedding documentEmbedding) {
        // Object-valued: Metadata.toMap() yields Map<String, Object>, and existing entries
        // need not be Strings.
        Map<String, Object> metadata = new HashMap<>();
        if (segment.metadata() != null) {
            metadata.putAll(segment.metadata().toMap());
        }
        putEncoded(metadata, DOCUMENT_EMBEDDING_KEY, documentEmbedding);
        putEncoded(metadata, QUERY_EMBEDDING_KEY, queryEmbedding);
        return TextSegment.from(segment.text(), Metadata.from(metadata));
    }

    /**
     * Decodes the document embedding stored in {@code segment}'s metadata.
     *
     * @param segment the segment to inspect
     * @return the decoded embedding, or {@code null} if the metadata is absent, the key is
     *     missing, the value is not a String, or the value is not a valid encoded embedding
     */
    public static Embedding extractDocumentEmbedding(TextSegment segment) {
        return extractEmbedding(segment, DOCUMENT_EMBEDDING_KEY);
    }

    /**
     * Decodes the query embedding stored in {@code segment}'s metadata.
     *
     * @param segment the segment to inspect
     * @return the decoded embedding, or {@code null} if the metadata is absent, the key is
     *     missing, the value is not a String, or the value is not a valid encoded embedding
     */
    public static Embedding extractQueryEmbedding(TextSegment segment) {
        return extractEmbedding(segment, QUERY_EMBEDDING_KEY);
    }

    /** Encodes {@code embedding} and stores it under {@code key}; skips it if nothing encodes. */
    private static void putEncoded(Map<String, Object> metadata, String key, Embedding embedding) {
        String encoded = embeddingToBase64(embedding);
        if (encoded != null) {
            metadata.put(key, encoded);
        }
    }

    /** Reads {@code key} from the segment's metadata and decodes it if it is a String. */
    private static Embedding extractEmbedding(TextSegment segment, String key) {
        if (segment.metadata() == null) {
            return null;
        }
        Object stored = segment.metadata().toMap().get(key);
        return stored instanceof String ? base64ToEmbedding((String) stored) : null;
    }

    /**
     * Serializes the embedding's vector as Base64 of its big-endian float bytes.
     *
     * @return the encoded string, or {@code null} if {@code embedding} or its vector is null
     */
    private static String embeddingToBase64(Embedding embedding) {
        if (embedding == null) {
            return null;
        }
        float[] vector = embedding.vector();
        if (vector == null) {
            return null;
        }
        // Debug rather than warn: this runs for every enriched segment, so a warning would flood
        // the logs without anything actionable for the operator.
        log.debug("Storing embedding as a base64 string due to metadata type constraints. "
                + "See dev.langchain4j.data.document.Metadata for supported types");
        ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * vector.length);
        buffer.asFloatBuffer().put(vector);
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    /**
     * Inverse of {@link #embeddingToBase64(Embedding)}.
     *
     * @return the decoded embedding, or {@code null} if {@code base64} is null, is not valid
     *     Base64, or decodes to a byte count that is not a whole number of floats
     */
    private static Embedding base64ToEmbedding(String base64) {
        if (base64 == null) {
            return null;
        }
        log.debug("Converting base64 string back to embedding due to metadata type limitations. "
                + "See dev.langchain4j.data.document.Metadata for supported types");
        byte[] bytes;
        try {
            bytes = Base64.getDecoder().decode(base64);
        } catch (IllegalArgumentException e) {
            // The key may hold an unrelated user-supplied String; treat it as "no embedding"
            // instead of failing the whole retrieval.
            log.warn("Ignoring metadata value that is not valid base64; cannot decode an embedding", e);
            return null;
        }
        if (bytes.length % Float.BYTES != 0) {
            // A partial trailing float means the payload is corrupt; truncating would yield a
            // silently wrong vector.
            log.warn("Ignoring base64 embedding whose decoded length ({} bytes) is not a multiple of {}",
                    bytes.length, Float.BYTES);
            return null;
        }
        float[] vector = new float[bytes.length / Float.BYTES];
        ByteBuffer.wrap(bytes).asFloatBuffer().get(vector);
        return new Embedding(vector);
    }

    /**
     * Returns whether {@code segment} carries a decodable document embedding.
     *
     * @param segment the segment to inspect
     * @return {@code true} iff {@link #extractDocumentEmbedding(TextSegment)} is non-null
     */
    public static boolean hasDocumentEmbedding(TextSegment segment) {
        return extractDocumentEmbedding(segment) != null;
    }

    /**
     * Returns whether {@code segment} carries a decodable query embedding.
     *
     * @param segment the segment to inspect
     * @return {@code true} iff {@link #extractQueryEmbedding(TextSegment)} is non-null
     */
    public static boolean hasQueryEmbedding(TextSegment segment) {
        return extractQueryEmbedding(segment) != null;
    }
}

