package cn.org.atool.generator.util;

import cn.org.atool.fluent.mybatis.metadata.DbType;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.mybatis.generator.api.IntrospectedTable;
import org.mybatis.generator.api.JavaTypeResolver;
import org.mybatis.generator.api.ProgressCallback;
import org.mybatis.generator.api.VerboseProgressCallback;
import org.mybatis.generator.config.Context;
import org.mybatis.generator.config.TableConfiguration;
import org.mybatis.generator.internal.ObjectFactory;
import org.mybatis.generator.internal.db.DatabaseIntrospector;

import java.sql.Connection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import static org.mybatis.generator.internal.util.StringUtility.composeFullyQualifiedTableName;
import static org.mybatis.generator.internal.util.messages.Messages.getString;

/**
 * Reads table metadata from a live JDBC connection using MyBatis Generator's
 * {@link DatabaseIntrospector}.
 *
 * <p>Instances are not synchronized: {@link #findTables(Collection)} lazily assigns and caches
 * {@link #javaTypeResolver} and appends to an internal warnings list.
 *
 * @author darui.wu
 */
@Accessors(chain = true)
public class TableKits {
    /**
     * Database types for which the configured schema is also set as the table catalog.
     * MySQL looks tables up by catalog rather than schema; for details see the article
     * "MyBatis Generator踩坑与自救" (MyBatis Generator pitfalls and workarounds):
     * https://www.jianshu.com/p/dbeeac29ff27
     */
    private static final List<DbType> CATALOG_AS_SCHEMA_DB_TYPES = Arrays.asList(
        DbType.MYSQL, DbType.MARIADB
    );

    private final DbType dbType;

    private final Connection connection;

    /**
     * Schema to look tables up in. May be given as {@code "catalog.schema"}, in which case it is
     * split on the first {@code '.'}. May be null.
     */
    @Setter
    private String schema;

    /**
     * Optional type resolver. When null, one is created via
     * {@code ObjectFactory.createJavaTypeResolver} on the first {@link #findTables} call and then
     * cached in this field.
     */
    @Setter
    private JavaTypeResolver javaTypeResolver;

    /** Warnings reported by MyBatis Generator; accumulated across calls and not read by this class. */
    private final List<String> warnings = new ArrayList<>();

    TableKits(DbType dbType, Connection connection) {
        this.dbType = dbType;
        this.connection = connection;
    }

    /**
     * Introspects the given tables through MyBatis Generator.
     *
     * @param tables names of the tables to introspect; must not be null
     * @return the introspected tables, in the order of {@code tables}; a new mutable list, empty
     *     when {@code tables} is empty
     * @throws NullPointerException if {@code tables} is null
     * @throws RuntimeException wrapping any failure raised while introspecting
     */
    public List<IntrospectedTable> findTables(Collection<String> tables) {
        Objects.requireNonNull(tables, "tables");
        List<IntrospectedTable> introspectedTables = new ArrayList<>();
        if (tables.isEmpty()) {
            // Nothing to look up: avoid touching the connection's metadata at all.
            return introspectedTables;
        }
        try {
            ProgressCallback callback = new VerboseProgressCallback();
            Context ctx = getContext(tables);
            if (javaTypeResolver == null) {
                // No resolver supplied by the caller: build one from this context and cache it
                // so subsequent calls reuse it.
                javaTypeResolver = ObjectFactory.createJavaTypeResolver(ctx, warnings);
            }
            DatabaseIntrospector databaseIntrospector =
                new DatabaseIntrospector(ctx, connection.getMetaData(), javaTypeResolver, warnings);
            for (TableConfiguration tc : ctx.getTableConfigurations()) {
                List<IntrospectedTable> ts = introspectTables(tc, callback, databaseIntrospector);
                // Guard against a null result from the introspector.
                if (ts != null) {
                    introspectedTables.addAll(ts);
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException("find tables error:" + e.getMessage(), e);
        } catch (Exception e) {
            throw new RuntimeException("find tables error:" + e.getMessage(), e);
        }
        return introspectedTables;
    }

    /**
     * Introspects a single table configuration, reporting progress through {@code callback}.
     *
     * @return whatever {@link DatabaseIntrospector#introspectTables} returns (may be null)
     */
    private List<IntrospectedTable> introspectTables(TableConfiguration tc, ProgressCallback callback, DatabaseIntrospector databaseIntrospector) throws Exception {
        String tableName = composeFullyQualifiedTableName(tc.getCatalog(), tc.getSchema(), tc.getTableName(), '.');

        callback.startTask(getString("Progress.1", tableName));
        List<IntrospectedTable> tables = databaseIntrospector.introspectTables(tc);
        callback.checkCancel();
        return tables;
    }

    /** Builds a MyBatis3 generator context holding one table configuration per table name. */
    private Context getContext(Collection<String> tables) {
        Context ctx = new Context(null);
        ctx.setTargetRuntime("MyBatis3");
        ctx.setId("MySqlContext");
        for (String tableName : tables) {
            ctx.addTableConfiguration(this.tableConfiguration(ctx, tableName));
        }
        return ctx;
    }

    /**
     * Creates the table configuration for {@code tableName}, with delimited identifiers and the
     * catalog/schema derived from {@link #schema}.
     */
    private TableConfiguration tableConfiguration(Context ctx, String tableName) {
        TableConfiguration conf = new TableConfiguration(ctx);
        conf.setDelimitIdentifiers(true);

        int index = this.schema == null ? -1 : this.schema.indexOf('.');
        if (index >= 0) {
            // "catalog.schema" form: split on the first '.'.
            conf.setCatalog(this.schema.substring(0, index));
            conf.setSchema(this.schema.substring(index + 1));
        } else {
            conf.setSchema(this.schema);
            // MySQL/MariaDB look tables up by catalog, so mirror the schema into the catalog.
            if (CATALOG_AS_SCHEMA_DB_TYPES.contains(dbType)) {
                conf.setCatalog(this.schema);
            }
        }
        conf.setTableName(tableName);
        return conf;
    }
}