package com.github.database.rider.junit5;

import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.api.configuration.DBUnit;
import com.github.database.rider.core.api.configuration.DataSetMergingStrategy;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.junit5.jdbc.ConnectionManager;
import com.github.database.rider.junit5.util.Constants;
import com.github.database.rider.junit5.util.EntityManagerProvider;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.StringUtils;

/* loaded from: input_file:com/github/database/rider/junit5/DBUnitExtension.class */
/**
 * JUnit 5 extension that integrates Database Rider with the Jupiter test lifecycle.
 *
 * <p>Around each test method it resolves the effective {@link DBUnit} configuration, optionally
 * measures connection leaks, and delegates dataset handling to {@link RiderRunner}. For lifecycle
 * callback methods ({@link BeforeAll}, {@link BeforeEach}, {@link AfterEach}, {@link AfterAll})
 * declared on the test class or its direct superclass, it applies any {@link DataSet} and verifies
 * any {@link ExpectedDataSet} found on them.
 *
 * <p>One {@link DBUnitTestContext} per test class is kept in the extension store under
 * {@link Constants#NAMESPACE}, keyed by the test class.
 */
public class DBUnitExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback, BeforeEachCallback, AfterEachCallback, BeforeAllCallback, AfterAllCallback {

    private static final Logger LOG = Logger.getLogger(DBUnitExtension.class.getName());

    /**
     * Prepares the executor for the current test method: resolves its DBUnit configuration,
     * starts the leak hunter when enabled, then runs the {@link RiderRunner} setup and
     * before-test steps.
     */
    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        EntityManagerProvider.clear();
        DBUnitTestContext testContext = getTestContext(extensionContext);
        DataSetExecutor executor = testContext.getExecutor();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.empty(), extensionContext.getTestMethod(), extensionContext.getRequiredTestClass());
        executor.setDBUnitConfig(dbUnitConfig);
        if (dbUnitConfig.isLeakHunter()) {
            String testName = extensionContext.getRequiredTestMethod().getName();
            try {
                LeakHunter leakHunter = LeakHunterFactory.from(executor.getRiderDataSource(), testName, dbUnitConfig.isCacheConnection());
                leakHunter.measureConnectionsBeforeExecution();
                testContext.setLeakHunter(leakHunter);
            } catch (SQLException e) {
                // Drop any hunter left over from a previous test so afterTestExecution neither
                // dereferences null nor compares against another test's baseline.
                testContext.setLeakHunter(null);
                LOG.log(Level.WARNING, String.format("Could not create leak hunter for test %s", testName), e);
            }
        }
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(testContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        riderRunner.setup(riderTestContext);
        riderRunner.runBeforeTest(riderTestContext);
    }

    /**
     * Runs the {@link RiderRunner} after-test step and, when enabled and successfully started,
     * the leak-hunter check; teardown always runs, even if either step throws.
     */
    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext testContext = getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = testContext.getExecutor().getDBUnitConfig();
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(testContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        try {
            riderRunner.runAfterTest(riderTestContext);
            LeakHunter leakHunter = testContext.getLeakHunter();
            // leakHunter is null when its creation failed in beforeTestExecution (already logged).
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter() && leakHunter != null) {
                leakHunter.checkConnectionsAfterExecution();
            }
        } finally {
            riderRunner.teardown(riderTestContext);
        }
    }

    @Override
    public void beforeEach(ExtensionContext extensionContext) throws Exception {
        executeCallbackDataSets(extensionContext, BeforeEach.class);
    }

    @Override
    public void afterEach(ExtensionContext extensionContext) throws Exception {
        executeCallbackDataSets(extensionContext, AfterEach.class);
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        executeCallbackDataSets(extensionContext, BeforeAll.class);
    }

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        executeCallbackDataSets(extensionContext, AfterAll.class);
    }

    /** Returns the per-test-class context from the extension store, creating it on first use. */
    private DBUnitTestContext getTestContext(ExtensionContext extensionContext) {
        return extensionContext.getStore(Constants.NAMESPACE).getOrComputeIfAbsent(extensionContext.getRequiredTestClass(), key -> createDBUnitTestContext(extensionContext), DBUnitTestContext.class);
    }

    /** Builds a new context around the executor registered under the resolved executor id. */
    private DBUnitTestContext createDBUnitTestContext(ExtensionContext extensionContext) {
        String executorId = getExecutorId(extensionContext, null);
        return new DBUnitTestContext(DataSetExecutorImpl.instance(executorId, ConnectionManager.getTestConnection(extensionContext, executorId)));
    }

    /**
     * Applies the dataset and verifies the expected dataset of every method in the test class
     * (or its direct superclass) annotated with {@code callback}. Does nothing when the context
     * has no test class.
     */
    private void executeCallbackDataSets(ExtensionContext extensionContext, Class<? extends Annotation> callback) throws DatabaseUnitException, SQLException {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return;
        }
        for (Method callbackMethod : findCallbackMethods(testClass.get(), callback)) {
            executeDataSetForCallback(extensionContext, callback, callbackMethod);
            executeExpectedDataSetForCallback(extensionContext, callback, callbackMethod);
        }
    }

    /**
     * Collects the methods declared directly on {@code testClass} or its direct superclass that
     * carry the {@code callback} annotation.
     *
     * @return an unmodifiable set; superclass methods come first so iteration order is stable
     */
    private Set<Method> findCallbackMethods(Class<?> testClass, Class<? extends Annotation> callback) {
        Class<?> superclass = testClass.getSuperclass();
        Stream<Method> inherited = superclass == null ? Stream.empty() : Stream.of(superclass.getDeclaredMethods());
        Set<Method> methods = Stream.concat(inherited, Stream.of(testClass.getDeclaredMethods()))
                .filter(method -> method.isAnnotationPresent(callback))
                .collect(Collectors.toCollection(LinkedHashSet::new));
        return Collections.unmodifiableSet(methods);
    }

    /**
     * Seeds the database with the {@link DataSet} declared on a callback method, merging it with
     * the class-level dataset when the resolved configuration asks for it. Logs a warning and
     * returns when the method has no {@link DataSet}.
     */
    private void executeDataSetForCallback(ExtensionContext extensionContext, Class<? extends Annotation> callback, Method callbackMethod) throws SQLException {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Optional<DataSet> methodDataSet = AnnotationUtils.findAnnotation(callbackMethod, DataSet.class);
        if (!methodDataSet.isPresent()) {
            LOG.warning("Could not find dataset annotation from callback method: " + callbackMethod);
            return;
        }
        EntityManagerProvider.clear();
        DBUnitTestContext testContext = getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callback), Optional.of(callbackMethod), testClass);
        DataSet dataSet = dbUnitConfig.isMergeDataSets()
                ? resolveDataSet(methodDataSet, AnnotationUtils.findAnnotation(testClass, DataSet.class), dbUnitConfig)
                : methodDataSet.get();
        DataSetExecutor executor = testContext.getExecutor();
        executor.setDBUnitConfig(dbUnitConfig);
        DataSetExecutor effectiveExecutor = resetExecutorConnectionIfNeeded(extensionContext, callback, dbUnitConfig, executor);
        effectiveExecutor.createDataSet(new DataSetConfig().from(dataSet));
        closeConnectionForAfterCallback(effectiveExecutor, callback);
    }

    /**
     * For after-callbacks without connection caching, closes the executor's JDBC connection
     * (if still open) and clears its rider data source.
     */
    private void closeConnectionForAfterCallback(DataSetExecutor dataSetExecutor, Class<?> callback) throws SQLException {
        if (!isAfterTestCallback(callback)
                || dataSetExecutor.getDBUnitConfig().isCacheConnection()
                || dataSetExecutor.getRiderDataSource().getDBUnitConnection().getConnection().isClosed()) {
            return;
        }
        dataSetExecutor.getRiderDataSource().getDBUnitConnection().getConnection().close();
        ((DataSetExecutorImpl) dataSetExecutor).clearRiderDataSource();
    }

    /**
     * Compares the current database state against the {@link ExpectedDataSet} declared on a
     * callback method. Logs a warning and returns when the method has no {@link ExpectedDataSet}.
     */
    private void executeExpectedDataSetForCallback(ExtensionContext extensionContext, Class<? extends Annotation> callback, Method callbackMethod) throws DatabaseUnitException, SQLException {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Optional<ExpectedDataSet> expectedAnnotation = AnnotationUtils.findAnnotation(callbackMethod, ExpectedDataSet.class);
        if (!expectedAnnotation.isPresent()) {
            LOG.warning("Could not find expectedDataSet annotation from callback method: " + callbackMethod);
            return;
        }
        ExpectedDataSet expectedDataSet = expectedAnnotation.get();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callback), Optional.of(callbackMethod), testClass);
        DataSetExecutor executor = getTestContext(extensionContext).getExecutor();
        executor.setDBUnitConfig(dbUnitConfig);
        DataSetExecutor effectiveExecutor = resetExecutorConnectionIfNeeded(extensionContext, callback, dbUnitConfig, executor);
        DataSetConfig expectedConfig = new DataSetConfig(expectedDataSet.value())
                .disableConstraints(true)
                .datasetProvider(expectedDataSet.provider());
        effectiveExecutor.compareCurrentDataSetWith(expectedConfig, expectedDataSet.ignoreCols(), expectedDataSet.replacers(), expectedDataSet.orderBy(), expectedDataSet.compareOperation());
        closeConnectionForAfterCallback(effectiveExecutor, callback);
    }

    /**
     * For after-callbacks without connection caching, returns an executor instance bound to a
     * fresh test connection (the previous one may have been closed); otherwise returns
     * {@code dataSetExecutor} unchanged.
     */
    private DataSetExecutor resetExecutorConnectionIfNeeded(ExtensionContext extensionContext, Class<?> callback, DBUnitConfig dbUnitConfig, DataSetExecutor dataSetExecutor) {
        if (dbUnitConfig.isCacheConnection() || !isAfterTestCallback(callback)) {
            return dataSetExecutor;
        }
        String executorId = dataSetExecutor.getExecutorId();
        return DataSetExecutorImpl.instance(executorId, ConnectionManager.getTestConnection(extensionContext, executorId), dbUnitConfig);
    }

    /** Returns true for {@link AfterEach} and {@link AfterAll}. */
    private boolean isAfterTestCallback(Class<?> callback) {
        return callback.equals(AfterEach.class) || callback.equals(AfterAll.class);
    }

    /**
     * Resolves the {@link DBUnit} configuration. Candidates are consulted in order: the method,
     * the test class, the first callback method of kind {@code callback} that declares
     * {@link DBUnit}, and the direct superclass. The global configuration is the fallback.
     */
    private DBUnitConfig resolveDbUnitConfig(Optional<Class<? extends Annotation>> callback, Optional<Method> method, Class<?> testClass) {
        Optional<DBUnit> dbUnit = AnnotationUtils.findAnnotation(method, DBUnit.class);
        if (!dbUnit.isPresent()) {
            dbUnit = AnnotationUtils.findAnnotation(testClass, DBUnit.class);
        }
        if (!dbUnit.isPresent() && callback.isPresent()) {
            // Look at every callback method, not just an arbitrary first one, so the result does
            // not depend on set iteration order.
            dbUnit = findCallbackMethods(testClass, callback.get()).stream()
                    .map(callbackMethod -> AnnotationUtils.findAnnotation(callbackMethod, DBUnit.class))
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .findFirst();
        }
        if (!dbUnit.isPresent() && testClass.getSuperclass() != null) {
            dbUnit = AnnotationUtils.findAnnotation(testClass.getSuperclass(), DBUnit.class);
        }
        return dbUnit.isPresent() ? DBUnitConfig.from(dbUnit.get()) : DBUnitConfig.fromGlobalConfig();
    }

    /**
     * Merges the callback method's dataset with the class-level dataset, if any.
     * {@link DataSetMergingStrategy#METHOD} passes the class dataset as the first argument to
     * {@code mergeDataSetAnnotations}; any other strategy passes the method dataset first.
     */
    private DataSet resolveDataSet(Optional<DataSet> methodDataSet, Optional<DataSet> classDataSet, DBUnitConfig dbUnitConfig) {
        if (!classDataSet.isPresent()) {
            return methodDataSet.get();
        }
        if (DataSetMergingStrategy.METHOD.equals(dbUnitConfig.getMergingStrategy())) {
            return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(classDataSet.get(), methodDataSet.get());
        }
        return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(methodDataSet.get(), classDataSet.get());
    }

    /**
     * Computes the executor id. It is the dataset's non-blank {@code executorId}, or
     * {@link Constants#JUNIT5_EXECUTOR}, suffixed with {@code "-" + beanName} when a data source
     * bean name is configured.
     *
     * @param dataSet explicit dataset to read the id from; when null, the test method's (or test
     *                class's) {@link DataSet} is used
     */
    private String getExecutorId(ExtensionContext extensionContext, DataSet dataSet) {
        Optional<DataSet> dataSetAnnotation = dataSet != null ? Optional.of(dataSet) : findDataSetAnnotation(extensionContext);
        String dataSourceBeanName = ConnectionManager.getConfiguredDataSourceBeanName(extensionContext);
        String suffix = dataSourceBeanName.isEmpty() ? Constants.EMPTY_STRING : "-" + dataSourceBeanName;
        return dataSetAnnotation.map(DataSet::executorId)
                .filter(StringUtils::isNotBlank)
                .map(executorId -> executorId + suffix)
                .orElseGet(() -> Constants.JUNIT5_EXECUTOR + suffix);
    }

    /**
     * Finds the {@link DataSet} on the current test method, falling back to the test class.
     * Returns empty when the context has no test method.
     */
    private Optional<DataSet> findDataSetAnnotation(ExtensionContext extensionContext) {
        Optional<Method> testMethod = extensionContext.getTestMethod();
        if (!testMethod.isPresent()) {
            // NOTE(review): in class-level contexts (e.g. beforeAll) a class-level @DataSet
            // executorId is therefore ignored — confirm this is intended.
            return Optional.empty();
        }
        Optional<DataSet> methodDataSet = AnnotationUtils.findAnnotation(testMethod.get(), DataSet.class);
        return methodDataSet.isPresent()
                ? methodDataSet
                : AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), DataSet.class);
    }
}
