package com.alibaba.cloud.ai.service.base;

import com.alibaba.cloud.ai.dbconnector.DbAccessor;
import com.alibaba.cloud.ai.dbconnector.DbConfig;
import com.alibaba.cloud.ai.dbconnector.MdTableGenerator;
import com.alibaba.cloud.ai.dbconnector.bo.DbQueryParameter;
import com.alibaba.cloud.ai.prompt.PromptHelper;
import com.alibaba.cloud.ai.schema.SchemaDTO;
import com.alibaba.cloud.ai.service.LlmService;
import com.alibaba.cloud.ai.util.DateTimeUtil;
import com.alibaba.cloud.ai.util.MarkdownParser;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Base natural-language-to-SQL pipeline.
 *
 * <p>Combines evidence retrieval from {@link BaseVectorStoreService}, schema recall from
 * {@link BaseSchemaService}, LLM calls through {@link LlmService} and query execution through
 * {@link DbAccessor} in order to rewrite a user question, pick the relevant tables, generate SQL
 * for the question and run that SQL.
 */
public class BaseNl2SqlService {

    /** Shared Gson instance; Gson instances are thread-safe and can be reused across calls. */
    private static final Gson GSON = new Gson();

    /** Reflective type of the JSON string arrays that LLM replies are parsed into. */
    private static final Type STRING_LIST_TYPE = new TypeToken<List<String>>() {
    }.getType();

    /** Second argument passed to {@code getDocuments} when fetching evidence snippets. */
    private static final String EVIDENCE_DOC_TYPE = "evidence";

    /** Reply line prefix ("requirement type:") that carries the intent category of the question. */
    private static final String INTENT_TYPE_PREFIX = "需求类型：";

    /** Reply line prefix ("requirement content:") that carries the rewritten question. */
    private static final String INTENT_CONTENT_PREFIX = "需求内容：";

    protected final BaseVectorStoreService vectorStoreService;
    protected final BaseSchemaService schemaService;
    /** LLM client. Kept public because external code may read it directly. */
    public final LlmService aiService;
    protected final DbAccessor dbAccessor;
    protected final DbConfig dbConfig;

    /**
     * Creates the service from its collaborators.
     *
     * @param baseVectorStoreService source of evidence snippets
     * @param baseSchemaService schema recall used by {@link #select}
     * @param llmService LLM client used for every prompt
     * @param dbAccessor executes generated SQL
     * @param dbConfig target database configuration
     */
    public BaseNl2SqlService(BaseVectorStoreService baseVectorStoreService, BaseSchemaService baseSchemaService, LlmService llmService, DbAccessor dbAccessor, DbConfig dbConfig) {
        this.vectorStoreService = baseVectorStoreService;
        this.schemaService = baseSchemaService;
        this.aiService = llmService;
        this.dbAccessor = dbAccessor;
        this.dbConfig = dbConfig;
    }

    /**
     * Asks the LLM to classify and rewrite the user question.
     *
     * <p>The LLM reply is scanned line by line. A {@code 需求类型：} line labelled {@code 《自由闲聊》}
     * (free chit-chat) or {@code 《需要澄清》} (needs clarification) short-circuits with a fixed marker
     * string. A {@code 需求内容：} line supplies the rewritten question; the last such line wins.
     *
     * @param str the raw user question
     * @return {@code "闲聊拒识"} (chit-chat rejected), {@code "意图模糊需要澄清"} (ambiguous intent,
     *     clarification needed), the rewritten question, or {@code str} itself when the reply is
     *     {@code null} or contains no rewritten question
     * @throws Exception propagated from evidence retrieval, schema selection or the LLM call
     */
    public String rewrite(String str) throws Exception {
        List<String> evidences = getEvidences(str);
        String reply = this.aiService.call(PromptHelper.buildRewritePrompt(str, select(str, evidences), evidences));
        if (reply == null) {
            return str;
        }
        String rewritten = str;
        for (String rawLine : reply.split("\\n")) {
            // Trim so indented lines and CRLF line endings ("\r") do not defeat the prefix checks.
            String line = rawLine.trim();
            if (line.startsWith(INTENT_TYPE_PREFIX)) {
                String intentType = line.substring(INTENT_TYPE_PREFIX.length()).trim();
                if ("《自由闲聊》".equals(intentType)) {
                    return "闲聊拒识";
                }
                if ("《需要澄清》".equals(intentType)) {
                    return "意图模糊需要澄清";
                }
            } else if (line.startsWith(INTENT_CONTENT_PREFIX)) {
                rewritten = line.substring(INTENT_CONTENT_PREFIX.length()).trim();
            }
        }
        return rewritten;
    }

    /**
     * Translates a natural-language question into SQL.
     *
     * @param str the user question
     * @return the generated SQL, trimmed
     * @throws Exception propagated from evidence retrieval, schema selection or the LLM calls
     */
    public String nl2sql(String str) throws Exception {
        List<String> evidences = getEvidences(str);
        return generateSql(evidences, str, select(str, evidences));
    }

    /**
     * Executes {@code str} against {@link #dbConfig} and renders the result with
     * {@link MdTableGenerator} (presumably a Markdown table).
     *
     * <p>NOTE(review): the SQL is executed as given; callers must make sure it is safe to run
     * (e.g. read-only), since it typically originates from an LLM.
     *
     * @param str SQL to execute
     * @return the rendered result table
     * @throws Exception propagated from the database accessor
     */
    public String executeSql(String str) throws Exception {
        return MdTableGenerator.generateTable(this.dbAccessor.executeSqlAndReturnObject(this.dbConfig, DbQueryParameter.from(this.dbConfig).setSql(str)));
    }

    /**
     * Selects the schema relevant to a question.
     *
     * <p>The question is concatenated with every evidence snippet (each followed by {@code "。"}).
     * The LLM extracts keywords from that text, {@link BaseSchemaService#mixRag} recalls candidate
     * tables, and {@link #fineSelect} lets the LLM prune them.
     *
     * @param str the user question
     * @param list evidence snippets
     * @return the pruned schema
     * @throws Exception propagated from the LLM, JSON parsing or schema recall
     */
    public SchemaDTO select(String str, List<String> list) throws Exception {
        StringBuilder sb = new StringBuilder(str);
        for (String evidence : list) {
            sb.append(evidence).append("。");
        }
        String enrichedQuery = sb.toString();
        List<String> keywords = parseStringList(this.aiService.call(PromptHelper.buildQueryToKeywordsPrompt(enrichedQuery)));
        return fineSelect(this.schemaService.mixRag(enrichedQuery, keywords), enrichedQuery, list);
    }

    /**
     * Generates SQL for a question using the given evidence and schema.
     *
     * <p>Date/time expressions in the question are extracted by the LLM, resolved relative to
     * today's date, and appended to a copy of the evidence as "X指的是Y" ("X refers to Y") hints,
     * so the SQL generator sees concrete dates.
     *
     * @param list evidence snippets; not modified
     * @param str the user question
     * @param schemaDTO schema to generate SQL against
     * @return the generated SQL with surrounding Markdown stripped and whitespace trimmed
     * @throws Exception propagated from the LLM calls or JSON parsing
     */
    public String generateSql(List<String> list, String str, SchemaDTO schemaDTO) throws Exception {
        // Work on a copy so the caller's evidence list is never mutated (it may be immutable or reused).
        List<String> evidences = list == null ? new ArrayList<>() : new ArrayList<>(list);
        evidences.addAll(buildDateTimeEvidences(str));
        List<String> prompts = PromptHelper.buildMixSqlGeneratorPrompt(str, this.dbConfig, schemaDTO, evidences);
        // Presumably (system prompt, user prompt) — TODO confirm against PromptHelper.
        String reply = this.aiService.callWithSystemPrompt(prompts.get(0), prompts.get(1));
        return MarkdownParser.extractRawText(reply).trim();
    }

    /**
     * Asks the LLM which tables of {@code schemaDTO} are needed and removes all others in place.
     *
     * <p>A blank or {@code null} reply, or an empty selection, leaves the schema untouched. Table
     * names are compared case-insensitively using {@link Locale#ROOT}, so the result does not depend
     * on the JVM default locale.
     *
     * @param schemaDTO candidate schema; its table list is modified in place
     * @param str the (evidence-enriched) question
     * @param list evidence snippets
     * @return {@code schemaDTO}
     * @throws IllegalStateException if the reply cannot be parsed or applied; the message is the
     *     extracted reply text and the cause is the underlying failure
     */
    public SchemaDTO fineSelect(SchemaDTO schemaDTO, String str, List<String> list) {
        String reply = this.aiService.call(PromptHelper.buildMixSelectorPrompt(list, str, schemaDTO));
        if (reply == null || reply.trim().isEmpty()) {
            return schemaDTO;
        }
        String extractText = MarkdownParser.extractText(reply);
        try {
            List<String> selected = parseStringList(extractText);
            if (selected != null && !selected.isEmpty()) {
                Set<String> selectedNames = selected.stream()
                    .filter(Objects::nonNull)
                    .map(name -> name.toLowerCase(Locale.ROOT))
                    .collect(Collectors.toSet());
                // A selection with no usable names is treated like an empty one: keep every table.
                if (!selectedNames.isEmpty()) {
                    schemaDTO.getTable().removeIf(table -> !selectedNames.contains(table.getName().toLowerCase(Locale.ROOT)));
                }
            }
        } catch (Exception e) {
            throw new IllegalStateException(extractText, e);
        }
        return schemaDTO;
    }

    /** Fetches the text of the evidence documents relevant to {@code query} from the vector store. */
    private List<String> getEvidences(String query) throws Exception {
        return this.vectorStoreService.getDocuments(query, EVIDENCE_DOC_TYPE).stream()
            .map(doc -> doc.getText())
            .collect(Collectors.toList());
    }

    /**
     * Extracts date/time expressions from {@code query} via the LLM and resolves them to hint strings.
     *
     * @return "X指的是Y" hints for each resolved expression; empty when nothing was extracted
     */
    private List<String> buildDateTimeEvidences(String query) {
        List<String> expressions = parseStringList(this.aiService.call(PromptHelper.buildDateTimeExtractPrompt(query)));
        List<String> hints = new ArrayList<>();
        if (expressions == null || expressions.isEmpty()) {
            return hints;
        }
        for (String resolved : DateTimeUtil.buildDateExpressions(expressions, LocalDate.now())) {
            // Entries ending in "=" apparently carry no resolved value; skip them.
            if (!resolved.endsWith("=")) {
                hints.add(resolved.replace("=", "指的是"));
            }
        }
        return hints;
    }

    /**
     * Parses a JSON array of strings.
     *
     * @return the parsed list, or {@code null} when {@code json} is {@code null} or empty (Gson
     *     contract)
     */
    private static List<String> parseStringList(String json) {
        return GSON.fromJson(json, STRING_LIST_TYPE);
    }
}
