/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.client.OpenAIClientAsync;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ConfigOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.configuration.description.Description;
import org.apache.flink.configuration.description.InlineElement;
import org.apache.flink.configuration.description.TextElement;
import org.apache.flink.model.openai.OpenAIUtils;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.catalog.Column;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractOpenAIModelFunction
extends AsyncPredictFunction {
    private static final Logger LOG = LoggerFactory.getLogger(AbstractOpenAIModelFunction.class);
    public static final ConfigOption<String> ENDPOINT = ConfigOptions.key((String)"endpoint").stringType().noDefaultValue().withDescription(Description.builder().text("Full URL of the OpenAI API endpoint, e.g., %s or %s", new InlineElement[]{TextElement.code((String)"https://api.openai.com/v1/chat/completions"), TextElement.code((String)"https://api.openai.com/v1/embeddings")}).build());
    public static final ConfigOption<String> API_KEY = ConfigOptions.key((String)"api-key").stringType().noDefaultValue().withDescription("OpenAI API key for authentication.");
    public static final ConfigOption<String> MODEL = ConfigOptions.key((String)"model").stringType().noDefaultValue().withDescription(Description.builder().text("Model name, e.g., %s, %s.", new InlineElement[]{TextElement.code((String)"gpt-3.5-turbo"), TextElement.code((String)"text-embedding-ada-002")}).build());
    protected transient OpenAIClientAsync client;
    private final int numRetry;
    private final String baseUrl;
    private final String apiKey;

    public AbstractOpenAIModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        String endpoint = (String)config.get(ENDPOINT);
        this.baseUrl = endpoint.replaceAll(String.format("/%s/*$", this.getEndpointSuffix()), "");
        this.apiKey = (String)config.get(API_KEY);
        this.numRetry = (Integer)config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) * 10;
        this.validateSingleColumnSchema(factoryContext.getCatalogModel().getResolvedInputSchema(), (LogicalType)new VarCharType(Integer.MAX_VALUE), "input");
    }

    public void open(FunctionContext context) throws Exception {
        super.open(context);
        LOG.debug("Creating an OpenAI client.");
        this.client = OpenAIUtils.createAsyncClient(this.baseUrl, this.apiKey, this.numRetry);
    }

    public void close() throws Exception {
        super.close();
        if (this.client != null) {
            LOG.debug("Releasing the OpenAI client.");
            OpenAIUtils.releaseAsyncClient(this.baseUrl, this.apiKey);
            this.client = null;
        }
    }

    protected abstract String getEndpointSuffix();

    protected void validateSingleColumnSchema(ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
        List columns = schema.getColumns();
        if (columns.size() != 1) {
            throw new IllegalArgumentException(String.format("Model should have exactly one %s column, but actually has %s columns: %s", inputOrOutput, columns.size(), columns.stream().map(Column::getName).collect(Collectors.toList())));
        }
        Column column = (Column)columns.get(0);
        if (!column.isPhysical()) {
            throw new IllegalArgumentException(String.format("%s column %s should be a physical column, but is a %s.", inputOrOutput, column.getName(), column.getClass()));
        }
        if (!expectedType.equals((Object)column.getDataType().getLogicalType())) {
            throw new IllegalArgumentException(String.format("%s column %s should be %s, but is a %s.", inputOrOutput, column.getName(), expectedType, column.getDataType().getLogicalType()));
        }
    }
}

