/*
 * Decompiled with CFR 0.152.
 */
package io.trino.benchto.driver.execution;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListeningExecutorService;
import io.trino.benchto.driver.Benchmark;
import io.trino.benchto.driver.BenchmarkExecutionException;
import io.trino.benchto.driver.BenchmarkProperties;
import io.trino.benchto.driver.Query;
import io.trino.benchto.driver.concurrent.ExecutorServiceFactory;
import io.trino.benchto.driver.execution.BenchmarkExecutionResult;
import io.trino.benchto.driver.execution.ExecutionSynchronizer;
import io.trino.benchto.driver.execution.QueryExecution;
import io.trino.benchto.driver.execution.QueryExecutionDriver;
import io.trino.benchto.driver.execution.QueryExecutionResult;
import io.trino.benchto.driver.execution.ResultComparisonException;
import io.trino.benchto.driver.listeners.benchmark.BenchmarkStatusReporter;
import io.trino.benchto.driver.loader.SqlStatementGenerator;
import io.trino.benchto.driver.macro.MacroService;
import io.trino.benchto.driver.utils.PermutationUtils;
import io.trino.benchto.driver.utils.QueryUtils;
import io.trino.benchto.driver.utils.TimeUtils;
import java.nio.file.Path;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

@Component
public class BenchmarkExecutionDriver {
    private static final Logger LOG = LoggerFactory.getLogger(BenchmarkExecutionDriver.class);
    // Executes one QueryExecution against a JDBC connection (see executeSingleQuery).
    @Autowired
    private QueryExecutionDriver queryExecutionDriver;
    // Receives benchmark/execution started and finished notifications.
    @Autowired
    private BenchmarkStatusReporter statusReporter;
    // Creates the executor used by executeQueries, sized by the benchmark's concurrency.
    @Autowired
    private ExecutorServiceFactory executorServiceFactory;
    // Runs the before/after benchmark macros and the before/after execution macros.
    @Autowired
    private MacroService macroService;
    // Awaited in executeBenchmarks after the prewarm runs, before a benchmark is reported as started.
    @Autowired
    private ExecutionSynchronizer executionSynchronizer;
    // Used to look up the DataSource bean named by the benchmark (see getConnectionFor).
    @Autowired
    private ApplicationContext applicationContext;
    // Supplies the warmup flag (isWarmup) and the query results directory (getQueryResultsDir).
    @Autowired
    private BenchmarkProperties properties;
    // Passed to every QueryExecution this driver constructs.
    @Autowired
    private SqlStatementGenerator sqlStatementGenerator;

    /**
     * Executes a group of benchmarks that share the same benchmark-level configuration.
     *
     * <p>The before-benchmark macros of the first benchmark are run once, then the group is either
     * warmed up or fully executed (depending on {@code properties.isWarmup()}), and finally the
     * after-benchmark macros of the first benchmark are run once.
     *
     * @param benchmarks non-empty group of benchmarks; all must agree on before/after benchmark
     *     macros, runs, prewarm runs, concurrency and throughput-test flag
     * @param benchmarkOrdinalNumber ordinal of the first benchmark of the group, used for logging only
     * @param benchmarkTotalCount total number of benchmarks, used for logging only
     * @param executionTimeLimit optional point in time after which execution is cut short
     * @return the results of the group; if the before-benchmark macros fail, or the after-benchmark
     *     macros fail while every benchmark succeeded, a single failed result for the first benchmark
     * @throws IllegalStateException if the list is empty or the benchmarks disagree on the shared settings
     */
    public List<BenchmarkExecutionResult> execute(List<Benchmark> benchmarks, int benchmarkOrdinalNumber, int benchmarkTotalCount, Optional<ZonedDateTime> executionTimeLimit) {
        Preconditions.checkState(!benchmarks.isEmpty(), "List of benchmarks to execute cannot be empty.");
        for (int i = 0; i < benchmarks.size(); ++i) {
            LOG.info("[{} of {}] processing benchmark: {}", benchmarkOrdinalNumber + i, benchmarkTotalCount, benchmarks.get(i));
        }

        // The group is driven by the settings of its first member, so all members must agree on them.
        Benchmark firstBenchmark = benchmarks.get(0);
        Preconditions.checkState(
                benchmarks.stream().allMatch(benchmark -> benchmark.getBeforeBenchmarkMacros().equals(firstBenchmark.getBeforeBenchmarkMacros())
                        && benchmark.getAfterBenchmarkMacros().equals(firstBenchmark.getAfterBenchmarkMacros())),
                "All benchmarks in a group must have the same before and after benchmark macros.");
        Preconditions.checkState(
                benchmarks.stream().allMatch(benchmark -> benchmark.getRuns() == firstBenchmark.getRuns()
                        && benchmark.getPrewarmRuns() == firstBenchmark.getPrewarmRuns()),
                "All benchmarks in a group must have the same number of runs and prewarm-runs.");
        Preconditions.checkState(
                benchmarks.stream().allMatch(benchmark -> benchmark.getConcurrency() == firstBenchmark.getConcurrency()
                        && benchmark.isThroughputTest() == firstBenchmark.isThroughputTest()),
                "All benchmarks in a group must have the same concurrency and either test throughput or not.");

        try {
            macroService.runBenchmarkMacros(firstBenchmark.getBeforeBenchmarkMacros(), firstBenchmark);
        }
        catch (Exception e) {
            return List.of(failedBenchmarkResult(firstBenchmark, e));
        }

        List<BenchmarkExecutionResult> benchmarkExecutionResults = properties.isWarmup()
                ? warmupBenchmarks(benchmarks, executionTimeLimit)
                : executeBenchmarks(benchmarks, executionTimeLimit);

        try {
            macroService.runBenchmarkMacros(firstBenchmark.getAfterBenchmarkMacros(), firstBenchmark);
        }
        catch (Exception e) {
            if (benchmarkExecutionResults.stream().allMatch(BenchmarkExecutionResult::isSuccessful)) {
                return List.of(failedBenchmarkResult(firstBenchmark, e));
            }
            // At least one benchmark had already failed: keep the original failures as the result
            // and only log the macro error (the trailing Throwable is logged as the exception).
            LOG.error("Error while running after benchmark macros for failed benchmark({})", firstBenchmark.getAfterBenchmarkMacros(), e);
        }
        return benchmarkExecutionResults;
    }

    /**
     * Runs only the prewarm runs of the given benchmarks and turns the outcome into results.
     *
     * <p>No executions are attached to the results. A benchmark is marked as failed when the prewarm
     * execution throws (all benchmarks are marked) or when any of its queries failed with a
     * {@link ResultComparisonException}.
     *
     * @param benchmarks non-empty group of benchmarks; prewarm run count is taken from the first one
     * @param executionTimeLimit optional time limit forwarded to query execution
     * @return one result per benchmark, in the order of {@code benchmarks}
     * @throws IllegalStateException if {@code benchmarks} contains the same benchmark twice
     */
    private List<BenchmarkExecutionResult> warmupBenchmarks(List<Benchmark> benchmarks, Optional<ZonedDateTime> executionTimeLimit) {
        Benchmark firstBenchmark = benchmarks.get(0);

        // Insertion-ordered so the returned results follow the order of the input benchmarks.
        Map<Benchmark, BenchmarkExecutionResult.BenchmarkExecutionResultBuilder> results = new LinkedHashMap<>();
        for (Benchmark benchmark : benchmarks) {
            BenchmarkExecutionResult.BenchmarkExecutionResultBuilder builder =
                    new BenchmarkExecutionResult.BenchmarkExecutionResultBuilder(benchmark).withExecutions(List.of());
            if (results.putIfAbsent(benchmark, builder) != null) {
                throw new IllegalStateException("Duplicate key " + benchmark);
            }
        }

        List<QueryExecutionResult> executions;
        try {
            executions = executeQueries(benchmarks, firstBenchmark.getPrewarmRuns(), true, executionTimeLimit);
        }
        catch (Exception e) {
            return results.values().stream()
                    .map(builder -> builder.withUnexpectedException(e).build())
                    .collect(Collectors.toList());
        }

        Map<Benchmark, String> comparisonFailures = BenchmarkExecutionDriver.getComparisonFailures(executions);
        List<BenchmarkExecutionResult> warmupResults = new ArrayList<>(results.size());
        for (Map.Entry<Benchmark, BenchmarkExecutionResult.BenchmarkExecutionResultBuilder> entry : results.entrySet()) {
            BenchmarkExecutionResult.BenchmarkExecutionResultBuilder builder = entry.getValue();
            String failure = comparisonFailures.getOrDefault(entry.getKey(), "");
            if (!failure.isEmpty()) {
                builder.withUnexpectedException(new RuntimeException(String.format("Query result comparison failed for queries: %s", failure)));
            }
            warmupResults.add(builder.build());
        }
        return warmupResults;
    }

    /**
     * Runs the prewarm runs and then the measured runs of the given benchmarks.
     *
     * <p>Flow:
     * <ol>
     *   <li>Prewarm runs are executed for all benchmarks; an exception fails every benchmark.</li>
     *   <li>Each benchmark is synchronized, reported as started and its timer is started. Benchmarks
     *       whose prewarm queries failed result comparison are marked failed, their timer is ended
     *       and they are excluded from the measured runs.</li>
     *   <li>Measured runs are executed for the remaining benchmarks; an exception fails every
     *       benchmark and returns without reporting any benchmark as finished.</li>
     *   <li>Executions are attached per benchmark, timers are ended, and every result is reported
     *       as finished.</li>
     * </ol>
     *
     * @param benchmarks non-empty group of benchmarks; run counts are taken from the first one
     * @param executionTimeLimit optional time limit forwarded to query execution
     * @return one result per benchmark, in the order of {@code benchmarks}
     * @throws IllegalStateException if {@code benchmarks} contains the same benchmark twice
     */
    private List<BenchmarkExecutionResult> executeBenchmarks(List<Benchmark> benchmarks, Optional<ZonedDateTime> executionTimeLimit) {
        Benchmark firstBenchmark = benchmarks.get(0);

        // Insertion-ordered so reporting and the returned results follow the order of the input benchmarks.
        Map<Benchmark, BenchmarkExecutionResult.BenchmarkExecutionResultBuilder> results = new LinkedHashMap<>();
        for (Benchmark benchmark : benchmarks) {
            BenchmarkExecutionResult.BenchmarkExecutionResultBuilder builder =
                    new BenchmarkExecutionResult.BenchmarkExecutionResultBuilder(benchmark).withExecutions(List.of());
            if (results.putIfAbsent(benchmark, builder) != null) {
                throw new IllegalStateException("Duplicate key " + benchmark);
            }
        }

        List<QueryExecutionResult> executions;
        try {
            executions = executeQueries(benchmarks, firstBenchmark.getPrewarmRuns(), true, executionTimeLimit);
        }
        catch (Exception e) {
            return results.values().stream()
                    .map(builder -> builder.withUnexpectedException(e).build())
                    .collect(Collectors.toList());
        }

        Map<Benchmark, String> comparisonFailures = BenchmarkExecutionDriver.getComparisonFailures(executions);
        List<Benchmark> validBenchmarks = new ArrayList<>(benchmarks);
        for (Map.Entry<Benchmark, BenchmarkExecutionResult.BenchmarkExecutionResultBuilder> entry : results.entrySet()) {
            Benchmark benchmark = entry.getKey();
            BenchmarkExecutionResult.BenchmarkExecutionResultBuilder result = entry.getValue();
            executionSynchronizer.awaitAfterBenchmarkExecutionAndBeforeResultReport(benchmark);
            statusReporter.reportBenchmarkStarted(benchmark);
            result.startTimer();
            String failure = comparisonFailures.getOrDefault(benchmark, "");
            if (failure.isEmpty()) {
                continue;
            }
            // Prewarm results did not match the expected ones: fail the benchmark and skip its measured runs.
            result.withUnexpectedException(new RuntimeException(String.format("Query result comparison failed for queries: %s", failure)));
            result.endTimer();
            validBenchmarks.remove(benchmark);
        }

        try {
            executions = executeQueries(validBenchmarks, firstBenchmark.getRuns(), false, executionTimeLimit);
        }
        catch (Exception e) {
            return results.values().stream()
                    .map(builder -> builder.withUnexpectedException(e).build())
                    .collect(Collectors.toList());
        }

        Map<Benchmark, List<QueryExecutionResult>> executionsByBenchmark = executions.stream()
                .collect(Collectors.groupingBy(QueryExecutionResult::getBenchmark, LinkedHashMap::new, Collectors.toList()));
        // Iterate over every benchmark that was started (not only those that produced executions),
        // so a benchmark that ended up with no executions still gets its timer ended.
        for (Benchmark benchmark : validBenchmarks) {
            results.get(benchmark)
                    .withExecutions(executionsByBenchmark.getOrDefault(benchmark, List.of()))
                    .endTimer();
        }

        return results.values().stream()
                .map(builder -> {
                    BenchmarkExecutionResult result = builder.build();
                    statusReporter.reportBenchmarkFinished(result);
                    return result;
                })
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Collects, per benchmark, a description of the queries whose result comparison failed.
     *
     * @param executions query execution results, possibly spanning several benchmarks
     * @return map from benchmark to a newline-separated, de-duplicated list of
     *     {@code "<query name> [<failure cause>]"} entries; benchmarks without a
     *     {@link ResultComparisonException} failure are absent from the map
     */
    private static Map<Benchmark, String> getComparisonFailures(List<QueryExecutionResult> executions) {
        Map<Benchmark, List<QueryExecutionResult>> executionsByBenchmark = executions.stream()
                .collect(Collectors.groupingBy(QueryExecutionResult::getBenchmark, LinkedHashMap::new, Collectors.toList()));

        Map<Benchmark, String> failures = new LinkedHashMap<>();
        for (Map.Entry<Benchmark, List<QueryExecutionResult>> entry : executionsByBenchmark.entrySet()) {
            String description = entry.getValue().stream()
                    .filter(BenchmarkExecutionDriver::isComparisonFailure)
                    .map(execution -> String.format("%s [%s]", execution.getQueryName(), execution.getFailureCause()))
                    .distinct()
                    .collect(Collectors.joining("\n"));
            // Every formatted entry is non-empty, so an empty description means no comparison failure.
            if (!description.isEmpty()) {
                failures.put(entry.getKey(), description);
            }
        }
        return failures;
    }

    /** Tells whether the execution failed with exactly a {@link ResultComparisonException} (subclasses do not count). */
    private static boolean isComparisonFailure(QueryExecutionResult execution) {
        return execution.getFailureCause() != null
                && execution.getFailureCause().getClass().equals(ResultComparisonException.class);
    }

    /**
     * Builds a result for the given benchmark that carries only the unexpected exception.
     *
     * @param benchmark benchmark the failure is attributed to
     * @param e exception to record on the result
     * @return the built result
     */
    private BenchmarkExecutionResult failedBenchmarkResult(Benchmark benchmark, Exception e) {
        BenchmarkExecutionResult.BenchmarkExecutionResultBuilder failedBuilder =
                new BenchmarkExecutionResult.BenchmarkExecutionResultBuilder(benchmark);
        return failedBuilder.withUnexpectedException(e).build();
    }

    /**
     * Executes the queries of the given benchmarks on a dedicated executor and waits for all of them.
     *
     * <p>For throughput tests one task per concurrency thread and benchmark is submitted, each of which
     * performs all {@code runs} itself. Otherwise one task per (run, benchmark, query) is submitted,
     * ordered by run first; {@code executionTimeLimit} is not used on that path.
     *
     * @param benchmarks benchmarks to execute; executor size and mode are taken from the first one
     * @param runs number of runs to perform
     * @param warmup whether these are prewarm runs
     * @param executionTimeLimit optional time limit, honoured by throughput tests
     * @return all query execution results; an empty list if {@code benchmarks} is empty
     * @throws BenchmarkExecutionException if waiting for the tasks is interrupted or a task fails
     */
    private List<QueryExecutionResult> executeQueries(List<Benchmark> benchmarks, int runs, boolean warmup, Optional<ZonedDateTime> executionTimeLimit) {
        if (benchmarks.isEmpty()) {
            return List.of();
        }
        Benchmark firstBenchmark = benchmarks.get(0);
        ListeningExecutorService executorService = executorServiceFactory.create(firstBenchmark.getConcurrency());
        try {
            if (firstBenchmark.isThroughputTest()) {
                List<Callable<List<QueryExecutionResult>>> threadCallables = new ArrayList<>();
                for (Benchmark benchmark : benchmarks) {
                    threadCallables.addAll(buildConcurrencyQueryExecutionCallables(benchmark, runs, warmup, executionTimeLimit));
                }
                // invokeAll is declared to return List<Future<T>>, so the futures are passed to
                // Futures.allAsList through a raw Iterable, exactly as the original code did.
                @SuppressWarnings({"unchecked", "rawtypes"})
                List<List<QueryExecutionResult>> perThreadResults =
                        (List<List<QueryExecutionResult>>) Futures.allAsList((Iterable) executorService.invokeAll(threadCallables)).get();
                return perThreadResults.stream()
                        .flatMap(Collection::stream)
                        .collect(ImmutableList.toImmutableList());
            }

            List<Callable<QueryExecutionResult>> queryCallables = new ArrayList<>();
            for (int run = 1; run <= runs; ++run) {
                for (Benchmark benchmark : benchmarks) {
                    queryCallables.addAll(buildQueryExecutionCallables(benchmark, run, warmup));
                }
            }
            @SuppressWarnings({"unchecked", "rawtypes"})
            List<QueryExecutionResult> queryResults =
                    (List<QueryExecutionResult>) Futures.allAsList((Iterable) executorService.invokeAll(queryCallables)).get();
            return queryResults;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new BenchmarkExecutionException("Could not execute benchmark", e);
        }
        catch (ExecutionException e) {
            throw new BenchmarkExecutionException("Could not execute benchmark", e);
        }
        finally {
            executorService.shutdown();
        }
    }

    /**
     * Creates one task per query of the benchmark for the given run.
     *
     * <p>Each task opens its own connection, executes the query and closes the connection again.
     * Status reporting is skipped for warmup runs.
     *
     * @param benchmark benchmark whose queries are executed
     * @param run 1-based run number
     * @param warmup whether this is a prewarm run
     * @return the tasks, in the order of the benchmark's queries
     */
    private List<Callable<QueryExecutionResult>> buildQueryExecutionCallables(Benchmark benchmark, int run, boolean warmup) {
        List<Callable<QueryExecutionResult>> executionCallables = new ArrayList<>();
        for (Query query : benchmark.getQueries()) {
            QueryExecution queryExecution = new QueryExecution(benchmark, query, run, sqlStatementGenerator);
            // A result file is only passed on for the first warmup run, or for any run of a non-select statement.
            Optional<Path> resultFile = benchmark.getQueryResults()
                    .filter(dir -> (warmup && run == 1) || !QueryUtils.isSelectQuery(query.getSqlTemplate()))
                    .map(queryResult -> properties.getQueryResultsDir().resolve(queryResult));
            executionCallables.add(() -> {
                try (Connection connection = getConnectionFor(queryExecution)) {
                    return executeSingleQuery(queryExecution, benchmark, connection, warmup, Optional.empty(), resultFile);
                }
            });
        }
        return executionCallables;
    }

    /**
     * Creates one task per concurrency thread of a throughput-test benchmark.
     *
     * <p>Each task prepares a query permutation for its thread number, executes its share of the
     * queries for all runs and, unless this is a warmup, reports the finished concurrency test.
     *
     * @param benchmark benchmark to execute
     * @param runs number of runs each thread performs
     * @param warmup whether these are prewarm runs
     * @param executionTimeLimit optional time limit forwarded to query execution
     * @return {@code benchmark.getConcurrency()} tasks, ordered by thread number
     */
    private List<Callable<List<QueryExecutionResult>>> buildConcurrencyQueryExecutionCallables(Benchmark benchmark, int runs, boolean warmup, Optional<ZonedDateTime> executionTimeLimit) {
        List<Callable<List<QueryExecutionResult>>> executionCallables = new ArrayList<>();
        for (int thread = 0; thread < benchmark.getConcurrency(); ++thread) {
            // Lambdas may only capture effectively final locals, hence the copy of the loop variable.
            int threadNumber = thread;
            executionCallables.add(() -> {
                LOG.info("Running throughput test: {} queries, {} runs", benchmark.getQueries().size(), runs);
                int[] queryOrder = PermutationUtils.preparePermutation(benchmark.getQueries().size(), threadNumber);
                List<QueryExecutionResult> queryExecutionResults = executeConcurrentQueries(benchmark, runs, warmup, executionTimeLimit, threadNumber, queryOrder);
                if (!warmup) {
                    statusReporter.reportConcurrencyTestExecutionFinished(queryExecutionResults);
                }
                return queryExecutionResults;
            });
        }
        return executionCallables;
    }

    /**
     * Executes the queries assigned to one throughput-test thread over a single connection.
     *
     * <p>During warmup the queries are split across threads (query {@code i} is run by thread
     * {@code i % concurrency}); otherwise every thread runs all queries in the order given by
     * {@code queryOrder}. Execution stops early, returning what was collected so far, as soon as a
     * query reports that the time limit was exceeded.
     *
     * @param benchmark benchmark being executed
     * @param runs number of runs to perform
     * @param warmup whether these are prewarm runs
     * @param executionTimeLimit optional time limit checked after every query
     * @param threadNumber 0-based number of the calling concurrency thread
     * @param queryOrder permutation of query indexes used for non-warmup runs
     * @return the results of the queries executed by this thread
     * @throws SQLException if the connection cannot be obtained or closed
     */
    private List<QueryExecutionResult> executeConcurrentQueries(Benchmark benchmark, int runs, boolean warmup, Optional<ZonedDateTime> executionTimeLimit, int threadNumber, int[] queryOrder) throws SQLException {
        boolean firstQuery = true;
        List<QueryExecutionResult> queryExecutionResults = new ArrayList<>();
        // try-with-resources closes the connection on every exit path, including the early return below.
        try (Connection connection = getConnectionFor(new QueryExecution(benchmark, benchmark.getQueries().get(0), 0, sqlStatementGenerator))) {
            for (int run = 1; run <= runs; ++run) {
                for (int queryIndex = 0; queryIndex < benchmark.getQueries().size(); ++queryIndex) {
                    int permutedQueryIndex = queryIndex;
                    if (warmup) {
                        if (queryIndex % benchmark.getConcurrency() != threadNumber) {
                            continue;
                        }
                        LOG.info("Executing pre-warm query {}", queryIndex);
                    }
                    else {
                        permutedQueryIndex = queryOrder[queryIndex];
                    }
                    Query query = benchmark.getQueries().get(permutedQueryIndex);
                    // Sequence id built from run, thread and query position: runs are offset by
                    // concurrency * queryCount and threads by queryCount.
                    int sequenceId = queryIndex
                            + threadNumber * benchmark.getQueries().size()
                            + (run - 1) * benchmark.getConcurrency() * benchmark.getQueries().size();
                    QueryExecution queryExecution = new QueryExecution(benchmark, query, sequenceId, sqlStatementGenerator);
                    if (firstQuery && !warmup) {
                        statusReporter.reportExecutionStarted(queryExecution);
                        firstQuery = false;
                    }
                    try {
                        queryExecutionResults.add(executeSingleQuery(queryExecution, benchmark, connection, true, executionTimeLimit));
                    }
                    catch (TimeLimitException e) {
                        LOG.warn("Interrupting benchmark {} due to time limit exceeded", benchmark.getName());
                        return queryExecutionResults;
                    }
                }
            }
        }
        return queryExecutionResults;
    }

    /**
     * Executes a single query without writing its result to an output file.
     *
     * <p>Convenience overload that delegates to the six-argument variant with an empty output file.
     *
     * @throws TimeLimitException if the time limit is exceeded once the query has been executed
     */
    private QueryExecutionResult executeSingleQuery(QueryExecution queryExecution, Benchmark benchmark, Connection connection, boolean skipReport, Optional<ZonedDateTime> executionTimeLimit) throws TimeLimitException {
        Optional<Path> noOutputFile = Optional.empty();
        return executeSingleQuery(queryExecution, benchmark, connection, skipReport, executionTimeLimit, noOutputFile);
    }

    /**
     * Executes a single query, surrounded by the benchmark's before/after execution macros.
     *
     * <p>A failure of the query itself is logged and converted into a failed result instead of being
     * propagated. If the time limit is exceeded after the query ran, the method throws before the
     * finished-report and the after-execution macros are run.
     *
     * @param queryExecution query execution to perform
     * @param benchmark benchmark the query belongs to
     * @param connection connection used for the macros and the query
     * @param skipReport when {@code true}, execution started/finished is not reported
     * @param executionTimeLimit optional point in time after which execution is aborted
     * @param outputFile optional file passed on to the query execution driver
     * @return the result of the query, successful or failed
     * @throws TimeLimitException if the time limit is exceeded once the query has been executed
     */
    private QueryExecutionResult executeSingleQuery(QueryExecution queryExecution, Benchmark benchmark, Connection connection, boolean skipReport, Optional<ZonedDateTime> executionTimeLimit, Optional<Path> outputFile) throws TimeLimitException {
        boolean reportingEnabled = !skipReport;

        macroService.runBenchmarkMacros(benchmark.getBeforeExecutionMacros(), benchmark, connection);
        if (reportingEnabled) {
            statusReporter.reportExecutionStarted(queryExecution);
        }

        // Started up front so that a failed result can be built with timings if execution throws.
        QueryExecutionResult.QueryExecutionResultBuilder timedFailureBuilder =
                (QueryExecutionResult.QueryExecutionResultBuilder) new QueryExecutionResult.QueryExecutionResultBuilder(queryExecution).startTimer();

        QueryExecutionResult outcome;
        try {
            outcome = queryExecutionDriver.execute(queryExecution, connection, outputFile);
        }
        catch (Exception e) {
            String message = String.format("Query Execution failed for benchmark %s query %s", benchmark.getName(), queryExecution.getQueryName());
            LOG.error(message, (Throwable) e);
            QueryExecutionResult.QueryExecutionResultBuilder stoppedBuilder =
                    (QueryExecutionResult.QueryExecutionResultBuilder) timedFailureBuilder.endTimer();
            outcome = (QueryExecutionResult) stoppedBuilder.failed(e).build();
        }

        if (isTimeLimitExceeded(executionTimeLimit)) {
            throw new TimeLimitException(benchmark, queryExecution);
        }

        if (reportingEnabled) {
            statusReporter.reportExecutionFinished(outcome);
        }
        macroService.runBenchmarkMacros(benchmark.getAfterExecutionMacros(), benchmark, connection);
        return outcome;
    }

    /**
     * Opens a connection from the data source bean named by the benchmark of the given execution.
     *
     * @param queryExecution execution whose benchmark names the data source
     * @return a connection obtained from that data source; closing it is up to the caller
     * @throws SQLException if the data source cannot provide a connection
     */
    private Connection getConnectionFor(QueryExecution queryExecution) throws SQLException {
        DataSource dataSource = applicationContext.getBean(queryExecution.getBenchmark().getDataSource(), DataSource.class);
        return dataSource.getConnection();
    }

    /**
     * Tells whether the given time limit lies before the current UTC time.
     *
     * @param executionTimeLimit optional limit; an empty value means "no limit"
     * @return {@code true} only if a limit is present and compares as less than {@code TimeUtils.nowUtc()}
     */
    private boolean isTimeLimitExceeded(Optional<ZonedDateTime> executionTimeLimit) {
        if (!executionTimeLimit.isPresent()) {
            return false;
        }
        ZonedDateTime limit = executionTimeLimit.get();
        return limit.compareTo(TimeUtils.nowUtc()) < 0;
    }

    /**
     * Thrown by {@code executeSingleQuery} when the execution time limit has passed once a query has run.
     */
    static class TimeLimitException
    extends RuntimeException {
        /**
         * @param benchmark benchmark that was being executed
         * @param queryExecution query execution after which the limit was found to be exceeded
         */
        public TimeLimitException(Benchmark benchmark, QueryExecution queryExecution) {
            super(describe(benchmark, queryExecution));
        }

        /** Formats the exception message from the benchmark name and the query name. */
        private static String describe(Benchmark benchmark, QueryExecution queryExecution) {
            String benchmarkName = benchmark.getName();
            String queryName = queryExecution.getQueryName();
            return String.format("Query execution exceeded time limit for benchmark %s query %s", benchmarkName, queryName);
        }
    }
}

