/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.models.embeddings.CreateEmbeddingResponse;
import com.openai.models.embeddings.EmbeddingCreateParams;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ConfigOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.model.openai.AbstractOpenAIModelFunction;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.LogicalType;

/**
 * Model function that sends the string in column 0 of each input row to the OpenAI
 * {@code embeddings} endpoint. Each returned embedding is emitted as a single-column row holding an
 * {@code ARRAY<FLOAT>}.
 */
public class OpenAIEmbeddingModelFunction
extends AbstractOpenAIModelFunction {
    private static final long serialVersionUID = 1L;

    /** Endpoint path appended to the configured base URL. */
    public static final String ENDPOINT_SUFFIX = "embeddings";

    /**
     * Optional requested embedding size. When unset, no {@code dimensions} value is sent and the
     * provider's default for the model applies.
     */
    public static final ConfigOption<Long> DIMENSION =
            ConfigOptions.key("dimension")
                    .longType()
                    .noDefaultValue()
                    .withDescription("Dimension of the embedding vector.");

    /** Model name read from the inherited {@code MODEL} option. */
    private final String model;

    /** Requested vector size; {@code null} when {@link #DIMENSION} is not configured. */
    @Nullable
    private final Long dimensions;

    /**
     * Creates the function and checks that the model's output schema is a single
     * {@code ARRAY<FLOAT>} column.
     *
     * @param factoryContext provider context giving access to the catalog model
     * @param config resolved model options
     */
    public OpenAIEmbeddingModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        super(factoryContext, config);
        this.model = config.get(MODEL);
        this.dimensions = config.get(DIMENSION);
        this.validateSingleColumnSchema(
                factoryContext.getCatalogModel().getResolvedOutputSchema(),
                new ArrayType(new FloatType()),
                "output");
    }

    @Override
    protected String getEndpointSuffix() {
        return ENDPOINT_SUFFIX;
    }

    /**
     * Requests an embedding for the string in column 0 of {@code rowData}.
     *
     * @param rowData input row whose field 0 is the text to embed
     * @return a future completing with one row per embedding in the response, or an already
     *     completed future with no rows when the input field is NULL
     */
    @Override
    public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {
        // Without this guard getString(0) returns null and toString() throws an NPE that fails the
        // whole async lookup; a NULL input has nothing to embed, so emit no rows for it.
        if (rowData.isNullAt(0)) {
            return CompletableFuture.completedFuture(Collections.emptyList());
        }

        EmbeddingCreateParams.Builder builder = EmbeddingCreateParams.builder();
        builder.model(model);
        builder.input(rowData.getString(0).toString());
        builder.encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT);
        if (dimensions != null) {
            builder.dimensions(dimensions);
        }
        return client.embeddings().create(builder.build()).thenApply(this::convertToRowData);
    }

    /** Converts each embedding in the response into a row containing its vector as floats. */
    private List<RowData> convertToRowData(CreateEmbeddingResponse response) {
        return response.data().stream()
                // The SDK exposes vectors as doubles; the declared output type is ARRAY<FLOAT>.
                .map(embedding -> embedding.embedding().stream().map(Double::floatValue).toArray(Float[]::new))
                .map(vector -> GenericRowData.of(new GenericArrayData(vector)))
                .collect(Collectors.toList());
    }
}

