/*
 * Decompiled with CFR 0.152.
 */
package com.mendmix.mybatis.plugin.pagination;

import com.mendmix.common.model.PageParams;
import com.mendmix.mybatis.datasource.DatabaseType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;

/**
 * Static helpers that derive paging ("limit") and row-count SQL from an
 * arbitrary SELECT statement.
 *
 * <p>All rewriting is textual (regular expressions over the SQL string); no SQL
 * parser is involved. Whenever a statement looks too complex to rewrite safely
 * (UNION, aggregation, DISTINCT, nested SELECT) the count query simply wraps
 * the original statement in a sub-select.
 */
public class PageSqlUtils {
    /** Line-control characters that are flattened to a single space before matching. */
    private static final String[] SQL_LINE_CHARS = {"\r", "\n", "\t"};
    private static final String[] SQL_LINE_REPLACE_CHARS = {" ", " ", " "};

    private static final String PAGE_SIZE_PLACEHOLDER = "#{pageSize}";
    private static final String OFFSET_PLACEHOLDER = "#{offset}";

    /**
     * Matches the select list, i.e. everything from SELECT up to (not including)
     * the first FROM keyword. Word boundaries keep identifiers such as
     * {@code from_date} from being mistaken for the FROM keyword.
     */
    private static final Pattern SQL_SELECT_PATTERN =
            Pattern.compile("\\bSELECT\\b.*?(?=\\bFROM\\b)", Pattern.CASE_INSENSITIVE);

    /** Matches the ORDER BY keyword pair in any letter case. */
    private static final Pattern SQL_ORDER_PATTERN =
            Pattern.compile("\\bORDER\\s+BY\\b", Pattern.CASE_INSENSITIVE);

    private static final String SQL_COUNT_PREFIX = "SELECT count(1) ";

    /** Fallback count statement: wraps the whole query in a derived table. */
    private static final String COMMON_COUNT_SQL_TEMPLATE = "select count(1) from (%s) tmp";

    /** Matches the UNION keyword in any letter case. */
    private static final Pattern UNION_PATTERN =
            Pattern.compile("\\bUNION\\b", Pattern.CASE_INSENSITIVE);

    /** Matches the start of a parenthesised sub-select, e.g. {@code ( select }. */
    private static final Pattern NESTED_SELECT_PATTERN =
            Pattern.compile("\\(\\s*SELECT\\s+", Pattern.CASE_INSENSITIVE);

    /** Patterns whose presence means the select list cannot simply be swapped for count(1). */
    private static final List<Pattern> AGGREGATION_KEY_PATTERNS = new ArrayList<>();

    /** Paging templates keyed by {@code DatabaseType.name()}; {@code %s} is the original SQL. */
    private static final Map<String, String> PAGE_TEMPLATES = new HashMap<>(4);

    static {
        PAGE_TEMPLATES.put(DatabaseType.mysql.name(), "%s limit #{offset},#{pageSize}");
        // #{offset} is the number of rows to skip (same value MySQL's LIMIT receives),
        // so the outer filter must be strictly greater-than; ">=" would also return
        // row number `offset` and overlap with the previous page.
        PAGE_TEMPLATES.put(DatabaseType.oracle.name(), "select * from (select a1.*,rownum rn from (%s) a1 where rownum <=#{offset} + #{pageSize}) where rn>#{offset}");
        PAGE_TEMPLATES.put(DatabaseType.postgresql.name(), "%s limit #{pageSize} offset #{offset}");
        PAGE_TEMPLATES.put(DatabaseType.h2.name(), "%s limit #{pageSize} offset #{offset}");

        AGGREGATION_KEY_PATTERNS.add(Pattern.compile("\\bGROUP\\s+BY\\b", Pattern.CASE_INSENSITIVE));
        // "\b" instead of "whitespace or comma" so that wrapped calls such as
        // IFNULL(SUM(x),0) and "COUNT (" are detected as well.
        AGGREGATION_KEY_PATTERNS.add(Pattern.compile("\\b(COUNT|MIN|MAX|SUM|AVG)\\s*\\(", Pattern.CASE_INSENSITIVE));
        AGGREGATION_KEY_PATTERNS.add(Pattern.compile("\\bDISTINCT\\b", Pattern.CASE_INSENSITIVE));
    }

    /**
     * Wraps {@code sql} in the paging template registered for {@code dbType}.
     * The result still contains the literal {@code #{offset}} and
     * {@code #{pageSize}} placeholders.
     *
     * @param dbType database dialect to generate paging SQL for
     * @param sql    the statement to page
     * @return the paging statement with unresolved placeholders
     * @throws IllegalArgumentException if no paging template is registered for {@code dbType}
     */
    public static String getLimitSQL(DatabaseType dbType, String sql) {
        String template = dbType == null ? null : PAGE_TEMPLATES.get(dbType.name());
        if (template == null) {
            // Previously surfaced as an opaque NullPointerException from String.format.
            throw new IllegalArgumentException("Unsupported database type for pagination: " + dbType);
        }
        return String.format(template, sql);
    }

    /**
     * Builds the paging statement for {@code dbType} and substitutes the
     * {@code #{offset}} and {@code #{pageSize}} placeholders with the values of
     * {@code pageParams.offset()} and {@code pageParams.getPageSize()}.
     *
     * @param dbType     database dialect to generate paging SQL for
     * @param sql        the statement to page
     * @param pageParams source of the offset and page size values
     * @return the paging statement with literal offset and page size
     * @throws IllegalArgumentException if no paging template is registered for {@code dbType}
     */
    public static String getLimitSQL(DatabaseType dbType, String sql, PageParams pageParams) {
        String offset = String.valueOf(pageParams.offset());
        String pageSize = String.valueOf(pageParams.getPageSize());
        return PageSqlUtils.getLimitSQL(dbType, sql)
                .replace(OFFSET_PLACEHOLDER, offset)
                .replace(PAGE_SIZE_PLACEHOLDER, pageSize);
    }

    /**
     * Derives a {@code count(1)} statement from a SELECT statement.
     *
     * <p>CR, LF and TAB characters are first replaced by spaces. If the statement
     * contains UNION, GROUP BY, DISTINCT, an aggregate function call or a
     * parenthesised sub-select, it is wrapped as
     * {@code select count(1) from (<sql>) tmp}. Otherwise everything from the
     * first ORDER BY onwards is dropped and the select list is replaced by
     * {@code SELECT count(1) }.
     *
     * @param sql the SELECT statement to count
     * @return the count statement
     */
    public static String getCountSql(String sql) {
        String formatSql = StringUtils.replaceEach(sql, SQL_LINE_CHARS, SQL_LINE_REPLACE_CHARS);
        if (requiresWrappedCount(formatSql)) {
            return String.format(COMMON_COUNT_SQL_TEMPLATE, formatSql);
        }
        // limit=2 keeps empty pieces, so element 0 always exists even when the
        // statement begins with (or consists only of) the ORDER BY keyword.
        String withoutOrderBy = SQL_ORDER_PATTERN.split(formatSql, 2)[0];
        return SQL_SELECT_PATTERN.matcher(withoutOrderBy).replaceFirst(SQL_COUNT_PREFIX);
    }

    /** Tells whether the statement is too complex for a plain select-list replacement. */
    private static boolean requiresWrappedCount(String formatSql) {
        if (UNION_PATTERN.matcher(formatSql).find()) {
            return true;
        }
        for (Pattern aggregationPattern : AGGREGATION_KEY_PATTERNS) {
            if (aggregationPattern.matcher(formatSql).find()) {
                return true;
            }
        }
        return NESTED_SELECT_PATTERN.matcher(formatSql).find();
    }

    /** Manual smoke check: prints the count and MySQL paging SQL for a sample statement. */
    public static void main(String[] args) throws IOException {
        String sql = "select a.*,\nSUM(a.c) from audited_policy a where 1=1\nand title like CONCAT('%',?,'%')\norder by updated_at desc";
        System.out.println(">>>>" + PageSqlUtils.getCountSql(sql));
        System.out.println(">>>>" + PageSqlUtils.getLimitSQL(DatabaseType.mysql, sql, new PageParams()));
    }
}

