/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.gateway.rest;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.SecurityOptions;
import org.apache.flink.core.testutils.BlockerSync;
import org.apache.flink.core.testutils.FlinkAssertions;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.rest.HttpMethodWrapper;
import org.apache.flink.runtime.rest.handler.HandlerRequest;
import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
import org.apache.flink.runtime.rest.messages.MessageHeaders;
import org.apache.flink.runtime.rest.messages.MessageParameters;
import org.apache.flink.runtime.rest.messages.RequestBody;
import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.util.RestClientException;
import org.apache.flink.runtime.rest.versioning.RestAPIVersion;
import org.apache.flink.runtime.rpc.exceptions.EndpointNotStartedException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
import org.apache.flink.table.gateway.api.SqlGatewayService;
import org.apache.flink.table.gateway.rest.SqlGatewayRestEndpoint;
import org.apache.flink.table.gateway.rest.handler.AbstractSqlGatewayRestHandler;
import org.apache.flink.table.gateway.rest.header.SqlGatewayMessageHeaders;
import org.apache.flink.table.gateway.rest.util.SqlGatewayRestAPIVersion;
import org.apache.flink.table.gateway.rest.util.SqlGatewayRestEndpointTestUtils;
import org.apache.flink.table.gateway.rest.util.TestingRestClient;
import org.apache.flink.table.gateway.rest.util.TestingSqlGatewayRestEndpoint;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.concurrent.FutureUtils;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowingConsumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class SqlGatewayRestEndpointITCase {
    private static final SqlGatewayService SERVICE = null;
    private static SqlGatewayRestEndpoint serverEndpoint;
    private static TestingRestClient restClient;
    private static InetSocketAddress serverAddress;
    private static TestBadCaseHeaders badCaseHeader;
    private static TestBadCaseHandler testHandler;
    private static TestVersionSelectionHeaders0 header0;
    private static TestVersionSelectionHeadersNot0 headerNot0;
    private static TestVersionHandler testVersionHandler0;
    private static TestVersionHandler testVersionHandlerNot0;
    private static Configuration config;
    private static final Duration timeout;

    SqlGatewayRestEndpointITCase() {
    }

    @BeforeEach
    void setup() throws Exception {
        header0 = new TestVersionSelectionHeaders0();
        headerNot0 = new TestVersionSelectionHeadersNot0();
        testVersionHandler0 = new TestVersionHandler(SERVICE, header0);
        testVersionHandlerNot0 = new TestVersionHandler(SERVICE, headerNot0);
        badCaseHeader = new TestBadCaseHeaders();
        testHandler = new TestBadCaseHandler(SERVICE);
        String address = InetAddress.getLoopbackAddress().getHostAddress();
        config = SqlGatewayRestEndpointTestUtils.getBaseConfig(SqlGatewayRestEndpointTestUtils.getFlinkConfig(address, address, "0"));
        serverEndpoint = TestingSqlGatewayRestEndpoint.builder(config, SERVICE).withHandler((RestHandlerSpecification)badCaseHeader, (ChannelInboundHandler)testHandler).withHandler((RestHandlerSpecification)header0, (ChannelInboundHandler)testVersionHandler0).withHandler((RestHandlerSpecification)headerNot0, (ChannelInboundHandler)testVersionHandlerNot0).buildAndStart();
        restClient = TestingRestClient.getTestingRestClient();
        serverAddress = serverEndpoint.getServerAddress();
    }

    @AfterEach
    void stop() throws Exception {
        if (restClient != null) {
            restClient.shutdown();
            restClient = null;
        }
        if (serverEndpoint != null) {
            serverEndpoint.stop();
            serverEndpoint = null;
        }
    }

    @Test
    void testSqlGatewayMessageHeaders() throws Exception {
        Assertions.assertThatThrownBy(() -> restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)headerNot0, (MessageParameters)EmptyMessageParameters.getInstance(), (RequestBody)EmptyRequestBody.getInstance(), Collections.emptyList(), (RestAPIVersion)SqlGatewayRestAPIVersion.V0)).satisfies(new ThrowingConsumer[]{FlinkAssertions.anyCauseMatches(IllegalArgumentException.class, (String)String.format("The requested version V0 is not supported by the request (method=%s URL=%s). Supported versions are: %s.", headerNot0.getHttpMethod(), headerNot0.getTargetRestEndpointURL(), headerNot0.getSupportedAPIVersions().stream().map(RestAPIVersion::getURLVersionPrefix).collect(Collectors.joining(","))))});
        CompletableFuture specifiedVersionResponse = restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)header0, (MessageParameters)EmptyMessageParameters.getInstance(), (RequestBody)EmptyRequestBody.getInstance(), Collections.emptyList(), (RestAPIVersion)SqlGatewayRestAPIVersion.V0);
        TestResponse testResponse0 = (TestResponse)specifiedVersionResponse.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
        Assertions.assertThat((String)testResponse0.getStatus()).isEqualTo("V0");
        CompletableFuture unspecifiedVersionResponse0 = restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)header0, (MessageParameters)EmptyMessageParameters.getInstance(), (RequestBody)EmptyRequestBody.getInstance(), Collections.emptyList());
        TestResponse testResponse1 = (TestResponse)unspecifiedVersionResponse0.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
        Assertions.assertThat((String)testResponse1.getStatus()).isEqualTo("V0");
        CompletableFuture unspecifiedVersionResponse1 = restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)headerNot0, (MessageParameters)EmptyMessageParameters.getInstance(), (RequestBody)EmptyRequestBody.getInstance(), Collections.emptyList());
        TestResponse testResponse2 = (TestResponse)unspecifiedVersionResponse1.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
        Assertions.assertThat((String)testResponse2.getStatus()).isEqualTo(((SqlGatewayRestAPIVersion)RestAPIVersion.getLatestVersion(headerNot0.getSupportedAPIVersions())).name());
    }

    @Test
    void testVersionSelection() throws Exception {
        for (SqlGatewayRestAPIVersion version : SqlGatewayRestAPIVersion.values()) {
            if (version == SqlGatewayRestAPIVersion.V0) continue;
            CompletableFuture versionResponse = restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)headerNot0, (MessageParameters)EmptyMessageParameters.getInstance(), (RequestBody)EmptyRequestBody.getInstance(), Collections.emptyList(), (RestAPIVersion)version);
            TestResponse testResponse = (TestResponse)versionResponse.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
            Assertions.assertThat((String)testResponse.getStatus()).isEqualTo(version.name());
        }
    }

    @Test
    void testDefaultVersionRouting() throws Exception {
        Assertions.assertThat((Boolean)((Boolean)config.get(SecurityOptions.SSL_REST_ENABLED))).isFalse();
        OkHttpClient client = new OkHttpClient();
        Request request = new Request.Builder().url(serverEndpoint.getRestBaseUrl() + header0.getTargetRestEndpointURL()).build();
        Response response = client.newCall(request).execute();
        assert (response.body() != null);
        Assertions.assertThat((String)response.body().string()).contains(new CharSequence[]{SqlGatewayRestAPIVersion.getDefaultVersion().name()});
    }

    @Test
    void testRequestInterleaving() throws Exception {
        BlockerSync sync = new BlockerSync();
        SqlGatewayRestEndpointITCase.testHandler.handlerBody = id -> {
            if (id == 1) {
                try {
                    sync.block();
                }
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
            return CompletableFuture.completedFuture(new TestResponse(id.toString()));
        };
        CompletableFuture<TestResponse> response1 = this.sendRequestToTestHandler(new TestRequest(1));
        sync.awaitBlocker();
        CompletableFuture<TestResponse> response2 = this.sendRequestToTestHandler(new TestRequest(2));
        Assertions.assertThat((String)response2.get().getStatus()).isEqualTo("2");
        sync.releaseBlocker();
        Assertions.assertThat((String)response1.get().getStatus()).isEqualTo("1");
    }

    @Test
    void testDuplicateHandlerRegistrationIsForbidden() {
        Assertions.assertThatThrownBy(() -> {
            try (TestingSqlGatewayRestEndpoint restServerEndpoint = TestingSqlGatewayRestEndpoint.builder(config, SERVICE).withHandler((RestHandlerSpecification)header0, (ChannelInboundHandler)testHandler).withHandler((RestHandlerSpecification)badCaseHeader, (ChannelInboundHandler)testHandler).build();){
                restServerEndpoint.start();
            }
        }).satisfies(new ThrowingConsumer[]{FlinkAssertions.anyCauseMatches(FlinkRuntimeException.class, (String)"Duplicate REST handler instance found. Please ensure each instance is registered only once.")});
    }

    @Test
    void testHandlerRegistrationOverlappingIsForbidden() {
        Assertions.assertThatThrownBy(() -> {
            try (TestingSqlGatewayRestEndpoint restServerEndpoint = TestingSqlGatewayRestEndpoint.builder(config, SERVICE).withHandler((RestHandlerSpecification)badCaseHeader, (ChannelInboundHandler)testHandler).withHandler((RestHandlerSpecification)badCaseHeader, (ChannelInboundHandler)testVersionHandler0).build();){
                restServerEndpoint.start();
            }
        }).satisfies(new ThrowingConsumer[]{FlinkAssertions.anyCauseMatches(FlinkRuntimeException.class, (String)"REST handler registration overlaps with another registration for")});
    }

    @Test
    void testShouldWaitForHandlersWhenClosing() throws Exception {
        SqlGatewayRestEndpointITCase.testHandler.closeFuture = new CompletableFuture();
        BlockerSync sync = new BlockerSync();
        SqlGatewayRestEndpointITCase.testHandler.handlerBody = id -> CompletableFuture.supplyAsync(() -> {
            try {
                sync.block();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return new TestResponse(id.toString());
        });
        CompletableFuture closeRestServerEndpointFuture = serverEndpoint.closeAsync();
        Assertions.assertThat((CompletableFuture)closeRestServerEndpointFuture).isNotDone();
        CompletableFuture<TestResponse> request = this.sendRequestToTestHandler(new TestRequest(1));
        sync.awaitBlocker();
        SqlGatewayRestEndpointITCase.testHandler.closeFuture.complete(null);
        Assertions.assertThat((CompletableFuture)closeRestServerEndpointFuture).isNotDone();
        sync.releaseBlocker();
        request.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
        closeRestServerEndpointFuture.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
    }

    @Test
    void testOnUnavailableRpcEndpointReturns503() {
        CompletableFuture<TestResponse> response = this.sendRequestToTestHandler(new TestRequest(3));
        Assertions.assertThatThrownBy(response::get).extracting(x -> ExceptionUtils.findThrowable((Throwable)x, RestClientException.class)).extracting(Optional::get).extracting(RestClientException::getHttpResponseStatus).isEqualTo((Object)HttpResponseStatus.SERVICE_UNAVAILABLE);
    }

    private CompletableFuture<TestResponse> sendRequestToTestHandler(TestRequest testRequest) {
        try {
            return restClient.sendRequest(serverAddress.getHostName(), serverAddress.getPort(), (MessageHeaders)badCaseHeader, (MessageParameters)EmptyMessageParameters.getInstance(), testRequest);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    static {
        timeout = Duration.ofSeconds(10L);
    }

    private static class TestBadCaseHandler
    extends AbstractSqlGatewayRestHandler<TestRequest, TestResponse, EmptyMessageParameters> {
        private final OneShotLatch closeLatch = new OneShotLatch();
        private CompletableFuture<Void> closeFuture = CompletableFuture.completedFuture(null);
        private Function<Integer, CompletableFuture<TestResponse>> handlerBody;

        TestBadCaseHandler(SqlGatewayService sqlGatewayService) {
            super(sqlGatewayService, Collections.emptyMap(), (MessageHeaders)badCaseHeader);
        }

        public CompletableFuture<Void> closeHandlerAsync() {
            this.closeLatch.trigger();
            return this.closeFuture;
        }

        protected CompletableFuture<TestResponse> handleRequest(@Nullable SqlGatewayRestAPIVersion version, @Nonnull HandlerRequest<TestRequest> request) {
            int id = ((TestRequest)request.getRequestBody()).id;
            if (id == 3) {
                return FutureUtils.completedExceptionally((Throwable)new EndpointNotStartedException("test exception"));
            }
            return this.handlerBody.apply(id);
        }
    }

    private static class TestVersionHandler
    extends AbstractSqlGatewayRestHandler<EmptyRequestBody, TestResponse, EmptyMessageParameters> {
        TestVersionHandler(SqlGatewayService sqlGatewayService, TestVersionSelectionHeadersBase header) {
            super(sqlGatewayService, Collections.emptyMap(), (MessageHeaders)header);
        }

        protected CompletableFuture<TestResponse> handleRequest(@Nullable SqlGatewayRestAPIVersion version, @Nonnull HandlerRequest<EmptyRequestBody> request) {
            assert (version != null);
            return CompletableFuture.completedFuture(new TestResponse(version.name()));
        }
    }

    private static class TestVersionSelectionHeadersNot0
    extends TestVersionSelectionHeadersBase {
        private TestVersionSelectionHeadersNot0() {
        }

        public Collection<SqlGatewayRestAPIVersion> getSupportedAPIVersions() {
            ArrayList<SqlGatewayRestAPIVersion> versions = new ArrayList<SqlGatewayRestAPIVersion>(Arrays.asList(SqlGatewayRestAPIVersion.values()));
            versions.remove(SqlGatewayRestAPIVersion.V0);
            return versions;
        }
    }

    private static class TestVersionSelectionHeaders0
    extends TestVersionSelectionHeadersBase {
        private TestVersionSelectionHeaders0() {
        }

        public Collection<SqlGatewayRestAPIVersion> getSupportedAPIVersions() {
            return Collections.singleton(SqlGatewayRestAPIVersion.V0);
        }
    }

    private static class TestVersionSelectionHeadersBase
    implements SqlGatewayMessageHeaders<EmptyRequestBody, TestResponse, EmptyMessageParameters> {
        private TestVersionSelectionHeadersBase() {
        }

        public Class<EmptyRequestBody> getRequestClass() {
            return EmptyRequestBody.class;
        }

        public HttpMethodWrapper getHttpMethod() {
            return HttpMethodWrapper.GET;
        }

        public String getTargetRestEndpointURL() {
            return "/test/select-version";
        }

        public Class<TestResponse> getResponseClass() {
            return TestResponse.class;
        }

        public HttpResponseStatus getResponseStatusCode() {
            return HttpResponseStatus.OK;
        }

        public String getDescription() {
            return null;
        }

        public EmptyMessageParameters getUnresolvedMessageParameters() {
            return EmptyMessageParameters.getInstance();
        }
    }

    private static class TestBadCaseHeaders
    implements SqlGatewayMessageHeaders<TestRequest, TestResponse, EmptyMessageParameters> {
        private TestBadCaseHeaders() {
        }

        public HttpMethodWrapper getHttpMethod() {
            return HttpMethodWrapper.POST;
        }

        public String getTargetRestEndpointURL() {
            return "/test/";
        }

        public Class<TestRequest> getRequestClass() {
            return TestRequest.class;
        }

        public Class<TestResponse> getResponseClass() {
            return TestResponse.class;
        }

        public HttpResponseStatus getResponseStatusCode() {
            return HttpResponseStatus.OK;
        }

        public String getDescription() {
            return "";
        }

        public EmptyMessageParameters getUnresolvedMessageParameters() {
            return EmptyMessageParameters.getInstance();
        }
    }

    private static class TestResponse
    implements ResponseBody {
        private final String status;

        @JsonCreator
        public TestResponse(@JsonProperty(value="status") String status) {
            this.status = status;
        }

        public String getStatus() {
            return this.status;
        }
    }

    private static class TestRequest
    implements RequestBody {
        public final int id;

        @JsonCreator
        public TestRequest(@JsonProperty(value="id") int id) {
            this.id = id;
        }
    }
}

