/*
 * Decompiled with CFR 0.152.
 */
package com.arcadedb.integration.importer.vector;

import com.arcadedb.database.Database;
import com.arcadedb.database.DatabaseFactory;
import com.arcadedb.database.DatabaseInternal;
import com.arcadedb.index.vector.HnswVectorIndexRAM;
import com.arcadedb.index.vector.VectorUtils;
import com.arcadedb.index.vector.distance.DistanceFunctionFactory;
import com.arcadedb.integration.importer.ConsoleLogger;
import com.arcadedb.integration.importer.ImporterContext;
import com.arcadedb.integration.importer.ImporterSettings;
import com.arcadedb.integration.importer.vector.TextFloatsEmbedding;
import com.arcadedb.schema.Type;
import com.arcadedb.utility.CodeUtils;
import com.arcadedb.utility.DateUtils;
import com.github.jelmerk.knn.DistanceFunction;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Imports text embeddings from an input stream, builds an in-memory HNSW index over them and persists that index
 * into an ArcadeDB database.
 *
 * <p>Each input line holds a word followed by its space-separated vector components.
 *
 * <p>Recognized importer options and their defaults:
 * <ul>
 *   <li>{@code distanceFunction} (default {@code InnerProduct})</li>
 *   <li>{@code vectorType} (default {@code Float})</li>
 *   <li>{@code vectorProperty}</li>
 *   <li>{@code idProperty} (default {@code name})</li>
 *   <li>{@code deletedProperty} (default {@code deleted})</li>
 *   <li>{@code m} (default 16)</li>
 *   <li>{@code ef} (default 256)</li>
 *   <li>{@code efConstruction} (default 256)</li>
 *   <li>{@code normalizeVectors} (default {@code false})</li>
 * </ul>
 */
public class TextEmbeddingsImporter {
    private final InputStream inputStream;
    private final ImporterSettings settings;
    private final ConsoleLogger logger;
    private int m;
    private int ef;
    private int efConstruction;
    private boolean normalizeVectors = false;
    private String databasePath;
    private boolean overwriteDatabase = false;
    private long errors = 0L;
    private long warnings = 0L;
    private DatabaseFactory factory;
    private Database database;
    private long beginTime;
    private boolean error = false;
    private ImporterContext context = new ImporterContext();
    private String vectorTypeName;
    private String distanceFunctionName;
    private String vectorPropertyName;
    private String idPropertyName = "name";
    private String deletedPropertyName = "deleted";
    // Progress counters read by printProgress(). They are volatile presumably because progress is
    // reported from a different thread than the import itself (verify against callers).
    private volatile long embeddingsParsed = 0L;
    // addAll() runs one indexing worker per available processor, so its progress listener may be
    // invoked concurrently: an atomic is required to avoid lost updates.
    private final AtomicLong indexedEmbedding = new AtomicLong();
    private volatile long verticesCreated = 0L;
    private volatile long verticesConnected = 0L;

    /**
     * Creates an importer that writes into {@code database}, reading embeddings from {@code inputStream}.
     *
     * @param database    target database; its path is also recorded for logging and re-creation
     * @param inputStream source of the embeddings, one entry per line; closed once it has been read
     * @param settings    importer settings and options (see the class documentation)
     */
    public TextEmbeddingsImporter(DatabaseInternal database, InputStream inputStream, ImporterSettings settings) throws ClassNotFoundException {
        this.settings = settings;
        this.database = database;
        this.databasePath = database.getDatabasePath();
        this.inputStream = inputStream;
        this.logger = new ConsoleLogger(settings.verboseLevel);
        // Names are normalized to "Capitalized" form because they are concatenated into the
        // distance-function implementation name (e.g. "Float" + "Cosine").
        this.distanceFunctionName = capitalize(settings.getValue("distanceFunction", "InnerProduct"));
        this.vectorTypeName = capitalize(settings.getValue("vectorType", "Float"));
        if (settings.options.containsKey("vectorProperty")) {
            this.vectorPropertyName = settings.getValue("vectorProperty", null);
        }
        if (settings.options.containsKey("idProperty")) {
            this.idPropertyName = settings.getValue("idProperty", null);
        }
        if (settings.options.containsKey("deletedProperty")) {
            this.deletedPropertyName = settings.getValue("deletedProperty", null);
        }
        this.m = settings.getIntValue("m", 16);
        this.ef = settings.getIntValue("ef", 256);
        this.efConstruction = settings.getIntValue("efConstruction", 256);
        if (settings.options.containsKey("normalizeVectors")) {
            this.normalizeVectors = Boolean.parseBoolean(settings.getValue("normalizeVectors", null));
        }
    }

    /**
     * Upper-cases the first character and lower-cases the rest (English locale).
     * {@code null} and empty strings are returned unchanged.
     */
    private static String capitalize(final String value) {
        if (value == null || value.isEmpty()) {
            return value;
        }
        return Character.toUpperCase(value.charAt(0)) + value.substring(1).toLowerCase(Locale.ENGLISH);
    }

    /**
     * Runs the whole import: loads and parses the embeddings, builds the HNSW index, persists it and logs a summary.
     *
     * @return the database the embeddings were imported into, or {@code null} if the target database already
     *         exists and overwriting is disabled
     * @throws IllegalArgumentException if no distance function matches the configured vector type and distance
     *                                  name, or if the vector type is not supported as a property type
     */
    public Database run() throws IOException, ClassNotFoundException, InterruptedException {
        if (!this.createDatabase()) {
            return null;
        }
        final DistanceFunction distanceFunction = this.resolveDistanceFunction();
        // Validate the vector type up front rather than after the (potentially long) index build.
        final Type vectorPropertyType = toVectorPropertyType(this.vectorTypeName);
        this.beginTime = System.currentTimeMillis();
        final List<TextFloatsEmbedding> texts = this.loadFromFile();
        this.skipLeadingEntries(texts);
        if (!texts.isEmpty()) {
            this.buildAndPersistIndex(texts, distanceFunction, vectorPropertyType);
        }
        this.logSummary(texts.size());
        return this.database;
    }

    /** Looks up the distance function named {@code vectorTypeName + distanceFunctionName}, failing if unknown. */
    private DistanceFunction resolveDistanceFunction() throws IOException, ClassNotFoundException, InterruptedException {
        final String implementationName = this.vectorTypeName + this.distanceFunctionName;
        final DistanceFunction distanceFunction = DistanceFunctionFactory.getImplementationByName(implementationName);
        if (distanceFunction == null) {
            throw new IllegalArgumentException("Distance function '" + implementationName + "' not supported");
        }
        return distanceFunction;
    }

    /** Maps the normalized vector type name to the array property type used to store vectors. */
    private static Type toVectorPropertyType(final String vectorTypeName) {
        return switch (vectorTypeName) {
            case "Short" -> Type.ARRAY_OF_SHORTS;
            case "Integer" -> Type.ARRAY_OF_INTEGERS;
            case "Long" -> Type.ARRAY_OF_LONGS;
            case "Float" -> Type.ARRAY_OF_FLOATS;
            case "Double" -> Type.ARRAY_OF_DOUBLES;
            default -> throw new IllegalArgumentException("Type '" + vectorTypeName + "' not supported");
        };
    }

    /**
     * Removes the first {@code documentsSkipEntries} entries in place.
     * The count is clamped to the list size, so skipping more entries than exist just empties the list.
     */
    private void skipLeadingEntries(final List<TextFloatsEmbedding> texts) {
        if (this.settings.documentsSkipEntries == null) {
            return;
        }
        final long skip = this.settings.documentsSkipEntries;
        if (skip <= 0L) {
            return;
        }
        texts.subList(0, (int) Math.min(skip, texts.size())).clear();
    }

    /** Builds the in-RAM HNSW index over {@code texts} (must be non-empty) and persists it into the database. */
    private void buildAndPersistIndex(final List<TextFloatsEmbedding> texts, final DistanceFunction distanceFunction,
                                      final Type vectorPropertyType) throws IOException, ClassNotFoundException, InterruptedException {
        // Every entry is expected to have the same dimensions; the first one defines the index size.
        final int dimensions = texts.getFirst().dimensions();
        this.logger.logLine(2, "- Parsed %,d embeddings with %,d dimensions in RAM", texts.size(), dimensions);
        final HnswVectorIndexRAM hnswIndex = HnswVectorIndexRAM.newBuilder(dimensions, distanceFunction, texts.size())
                .withM(this.m)
                .withEf(this.ef)
                .withEfConstruction(this.efConstruction)
                .build();
        // workDone is presumably the running count of indexed items (hnswlib ProgressListener contract).
        // Calls may arrive out of order from different workers, so keep the maximum value seen.
        hnswIndex.addAll(texts, Runtime.getRuntime().availableProcessors(),
                (workDone, max) -> this.indexedEmbedding.accumulateAndGet(workDone, Math::max), 1);
        hnswIndex.createPersistentIndex(this.database)
                .withVertexType(this.settings.vertexTypeName)
                .withEdgeType(this.settings.edgeTypeName)
                .withVectorProperty(this.vectorPropertyName, vectorPropertyType)
                .withIdProperty(this.idPropertyName)
                .withDeletedProperty(this.deletedPropertyName)
                .withVertexCreationCallback((record, item, total) -> ++this.verticesCreated)
                .withCallback((record, total) -> ++this.verticesConnected)
                .withBatchSize(1000)
                .create();
    }

    /** Logs the end-of-import summary at verbose level 1. */
    private void logSummary(final int embeddings) {
        this.logger.logLine(1, "***************************************************************************************************");
        this.logger.logLine(1, "Import of Text Embeddings database completed in %s with %,d errors and %,d warnings.", DateUtils.formatElapsed(System.currentTimeMillis() - this.beginTime), this.errors, this.warnings);
        this.logger.logLine(1, "\nSUMMARY\n");
        this.logger.logLine(1, "- Embeddings.................................: %,d", embeddings);
        this.logger.logLine(1, "***************************************************************************************************");
        this.logger.logLine(1, "");
        if (this.database != null) {
            this.logger.logLine(1, "NOTES:");
            // The path is passed as an argument, never concatenated into the format string.
            this.logger.logLine(1, "- you can find your new ArcadeDB database in '%s'", this.database.getDatabasePath());
        }
    }

    /**
     * Logs a one-line progress report at verbose level 2.
     *
     * <p>The approximate percentage is split across three phases, each measured against the number of parsed
     * embeddings:
     * <ul>
     *   <li>indexing: 0-10%</li>
     *   <li>vertex creation: 10-40%</li>
     *   <li>vertex connection: 40-100%</li>
     * </ul>
     */
    public void printProgress() {
        // Snapshot the counters once so the percentage and the detail parts agree with each other.
        final long parsed = this.embeddingsParsed;
        final long indexed = this.indexedEmbedding.get();
        final long created = this.verticesCreated;
        final long connected = this.verticesConnected;

        float progressPerc = 0.0f;
        if (parsed > 0L) {
            if (connected > 0L) {
                progressPerc = 40.0f + (float) connected * 60.0f / (float) parsed;
            } else if (created > 0L) {
                progressPerc = 10.0f + (float) created * 30.0f / (float) parsed;
            } else if (indexed > 0L) {
                progressPerc = (float) indexed * 10.0f / (float) parsed;
            }
        }

        final StringBuilder line = new StringBuilder("- %.2f%%".formatted(progressPerc));
        if (parsed > 0L) {
            line.append(" - %,d embeddings parsed".formatted(parsed));
        }
        if (indexed > 0L) {
            line.append(" - %,d embeddings indexed".formatted(indexed));
        }
        if (created > 0L) {
            line.append(" - %,d vertices created".formatted(created));
        }
        if (connected > 0L) {
            line.append(" - %,d vertices connected".formatted(connected));
        }
        line.append(" (elapsed ").append(DateUtils.formatElapsed(System.currentTimeMillis() - this.beginTime)).append(")");
        // The text contains '%' characters, so pass it as an argument rather than as the format string.
        this.logger.logLine(2, "%s", line.toString());
    }

    /**
     * Ensures a target database is available.
     *
     * <p>If none was supplied, a new database is created at {@code databasePath}. An existing database at that
     * path is dropped first, but only when overwriting is enabled.
     *
     * @return {@code false} (and flags the error) if the database already exists and overwriting is disabled
     */
    private boolean createDatabase() {
        if (this.database != null) {
            return true;
        }
        this.factory = new DatabaseFactory(this.databasePath);
        if (this.factory.exists()) {
            if (!this.overwriteDatabase) {
                this.logger.errorLine("Database already exists on path '%s'", this.databasePath);
                ++this.errors;
                this.error = true;
                return false;
            }
            this.database = this.factory.open();
            this.logger.errorLine("Found existent database at '%s', dropping it and recreate a new one", this.databasePath);
            this.database.drop();
        }
        this.database = this.factory.create();
        return true;
    }

    /**
     * Returns whether the import was aborted by a fatal error.
     * Currently this happens when the target database already exists and overwriting is disabled.
     */
    public boolean isError() {
        return this.error;
    }

    /** Returns the importer context. */
    public ImporterContext getContext() {
        return this.context;
    }

    /** Replaces the importer context; returns {@code this} for chaining. */
    public TextEmbeddingsImporter setContext(ImporterContext context) {
        this.context = context;
        return this;
    }

    /**
     * Reads the input stream and parses every non-blank line into an embedding.
     *
     * <p>At most {@code parsingLimitEntries} lines are read when that setting is positive. The stream is decoded
     * as UTF-8 and closed afterwards.
     *
     * @return a mutable list of the parsed embeddings, in input order
     */
    private List<TextFloatsEmbedding> loadFromFile() throws IOException {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(this.inputStream, StandardCharsets.UTF_8))) {
            Stream<String> lines = reader.lines();
            if (this.settings.parsingLimitEntries > 0L) {
                lines = lines.limit(this.settings.parsingLimitEntries);
            }
            // Token-count hint for the tokenizer (presumably an initial capacity): 301 fits a word plus a
            // 300-dimension vector, then it tracks the size of the previously parsed line.
            final AtomicInteger tokenCountHint = new AtomicInteger(301);
            return lines.filter(line -> !line.isBlank())
                    .map(line -> this.parseLine(line, tokenCountHint))
                    .collect(Collectors.toCollection(ArrayList::new));
        }
    }

    /** Parses one "word v1 v2 ... vn" line, normalizing the vector if requested. */
    private TextFloatsEmbedding parseLine(final String line, final AtomicInteger tokenCountHint) {
        ++this.embeddingsParsed;
        final List<String> tokens = CodeUtils.split(line, ' ', -1, tokenCountHint.get());
        tokenCountHint.set(tokens.size());
        final String word = tokens.getFirst();
        float[] vector = parseVector(tokens);
        if (this.normalizeVectors) {
            vector = VectorUtils.normalize(vector);
        }
        return new TextFloatsEmbedding(word, vector);
    }

    /**
     * Converts every token after the first (the word) into a vector component.
     * Blank tokens, such as those a trailing or repeated separator may produce, are skipped.
     */
    private static float[] parseVector(final List<String> tokens) {
        final float[] vector = new float[Math.max(0, tokens.size() - 1)];
        int length = 0;
        for (int i = 1; i < tokens.size(); ++i) {
            final String token = tokens.get(i);
            if (!token.isBlank()) {
                vector[length++] = Float.parseFloat(token);
            }
        }
        return length == vector.length ? vector : Arrays.copyOf(vector, length);
    }
}

