/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
 * the License. A copy of the License is located at
 * 
 * http://aws.amazon.com/apache2.0
 * 
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package software.amazon.awssdk.services.sagemakerruntime;

import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata;
import software.amazon.awssdk.awscore.internal.AwsServiceProtocol;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.handler.ClientExecutionParams;
import software.amazon.awssdk.core.client.handler.SyncClientHandler;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.http.HttpResponseHandler;
import software.amazon.awssdk.core.metrics.CoreMetric;
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.MetricPublisher;
import software.amazon.awssdk.metrics.NoOpMetricCollector;
import software.amazon.awssdk.protocols.core.ExceptionMetadata;
import software.amazon.awssdk.protocols.json.AwsJsonProtocol;
import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory;
import software.amazon.awssdk.protocols.json.JsonOperationMetadata;
import software.amazon.awssdk.services.sagemakerruntime.internal.SageMakerRuntimeServiceClientConfigurationBuilder;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalStreamFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointAsyncResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelStreamErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointAsyncRequestMarshaller;
import software.amazon.awssdk.services.sagemakerruntime.transform.InvokeEndpointRequestMarshaller;
import software.amazon.awssdk.utils.Logger;

/**
 * Internal implementation of {@link SageMakerRuntimeClient}.
 *
 * @see SageMakerRuntimeClient#builder()
 */
@Generated("software.amazon.awssdk:codegen")
@SdkInternalApi
final class DefaultSageMakerRuntimeClient implements SageMakerRuntimeClient {
    // Class-scoped logger; no method in this class currently writes to it.
    private static final Logger log = Logger.loggerFor(DefaultSageMakerRuntimeClient.class);

    // Protocol metadata (REST_JSON) attached to the execution params of every operation in this client.
    private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder()
            .serviceProtocol(AwsServiceProtocol.REST_JSON).build();

    // Handler that executes each operation; released in close().
    private final SyncClientHandler clientHandler;

    // Factory built once in the constructor via init(...); supplies response handlers, error handlers and marshallers.
    private final AwsJsonProtocolFactory protocolFactory;

    // Client-level configuration; the starting point for the per-request configuration of each call.
    private final SdkClientConfiguration clientConfiguration;

    // Service-level configuration exposed through serviceClientConfiguration(); its override configuration is
    // re-applied when request-level plugins are present.
    private final SageMakerRuntimeServiceClientConfiguration serviceClientConfiguration;

    /**
     * Creates the client from the supplied service-level and SDK-level configuration.
     *
     * @param serviceClientConfiguration
     *        configuration later returned by {@link #serviceClientConfiguration()}
     * @param clientConfiguration
     *        client-level configuration used to create the sync client handler and the protocol factory
     */
    protected DefaultSageMakerRuntimeClient(SageMakerRuntimeServiceClientConfiguration serviceClientConfiguration,
            SdkClientConfiguration clientConfiguration) {
        this.serviceClientConfiguration = serviceClientConfiguration;
        this.clientConfiguration = clientConfiguration;
        this.clientHandler = new AwsSyncClientHandler(clientConfiguration);
        // init(...) reads this.clientConfiguration, so the factory must be built after that field is assigned.
        this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build();
    }

    /**
     * <p>
     * After you deploy a model into production using Amazon SageMaker hosting services, your client applications use
     * this API to get inferences from the model hosted at the specified endpoint.
     * </p>
     * <p>
     * For an overview of Amazon SageMaker, see <a
     * href="https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works.html">How It Works</a>.
     * </p>
     * <p>
     * Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional
     * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax.
     * </p>
     * <p>
     * Calls to <code>InvokeEndpoint</code> are authenticated by using Amazon Web Services Signature Version 4. For
     * information, see <a
     * href="https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html">Authenticating
     * Requests (Amazon Web Services Signature Version 4)</a> in the <i>Amazon S3 API Reference</i>.
     * </p>
     * <p>
     * A customer's model containers must respond to requests within 60 seconds. The model itself can have a maximum
     * processing time of 60 seconds before responding to invocations. If your model is going to take 50-60 seconds of
     * processing time, the SDK socket timeout should be set to be 70 seconds.
     * </p>
     * <note>
     * <p>
     * Endpoints are scoped to an individual account, and are not public. The URL does not contain the account ID, but
     * Amazon SageMaker determines the account ID from the authentication token that is supplied by the caller.
     * </p>
     * </note>
     *
     * @param invokeEndpointRequest
     * @return Result of the InvokeEndpoint operation returned by the service.
     * @throws InternalFailureException
     *         An internal failure occurred.
     * @throws ServiceUnavailableException
     *         The service is unavailable. Try your call again.
     * @throws ValidationErrorException
     *         Inspect your request and try again.
     * @throws ModelErrorException
     *         Model (owned by the customer in the container) returned 4xx or 5xx error code.
     * @throws InternalDependencyException
     *         Your request caused an exception with an internal dependency. Contact customer support.
     * @throws ModelNotReadyException
     *         Either a serverless endpoint variant's resources are still being provisioned, or a multi-model endpoint
     *         is still downloading or loading the target model. Wait and try your request again.
     * @throws SdkException
     *         Base class for all exceptions that can be thrown by the SDK (both service and client). Can be used for
     *         catch all scenarios.
     * @throws SdkClientException
     *         If any client side error occurs such as an IO related failure, failure to get credentials, etc.
     * @throws SageMakerRuntimeException
     *         Base class for all service exceptions. Unknown exceptions will be thrown as an instance of this type.
     * @sample SageMakerRuntimeClient.InvokeEndpoint
     * @see <a href="https://docs.aws.amazon.com/goto/WebAPI/runtime.sagemaker-2017-05-13/InvokeEndpoint"
     *      target="_top">AWS API Documentation</a>
     */
    @Override
    public InvokeEndpointResponse invokeEndpoint(InvokeEndpointRequest invokeEndpointRequest) throws InternalFailureException,
            ServiceUnavailableException, ValidationErrorException, ModelErrorException, InternalDependencyException,
            ModelNotReadyException, AwsServiceException, SdkClientException, SageMakerRuntimeException {
        // Operation metadata: non-streaming success response, payload not flagged as JSON.
        JsonOperationMetadata metadata = JsonOperationMetadata.builder()
                .hasStreamingSuccessResponse(false)
                .isPayloadJson(false)
                .build();

        HttpResponseHandler<InvokeEndpointResponse> successHandler = protocolFactory.createResponseHandler(metadata,
                InvokeEndpointResponse::builder);
        HttpResponseHandler<AwsServiceException> failureHandler = createErrorResponseHandler(protocolFactory, metadata);

        // Request-level plugins (if any) are applied on top of the client-level configuration.
        SdkClientConfiguration requestConfiguration = updateSdkClientConfiguration(invokeEndpointRequest,
                this.clientConfiguration);

        RequestOverrideConfiguration requestOverride = invokeEndpointRequest.overrideConfiguration().orElse(null);
        List<MetricPublisher> publishers = resolveMetricPublishers(requestConfiguration, requestOverride);

        // Only create a real collector when at least one publisher will receive the metrics.
        MetricCollector collector;
        if (publishers.isEmpty()) {
            collector = NoOpMetricCollector.create();
        } else {
            collector = MetricCollector.create("ApiCall");
        }

        try {
            collector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime");
            collector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpoint");

            ClientExecutionParams<InvokeEndpointRequest, InvokeEndpointResponse> executionParams =
                    new ClientExecutionParams<InvokeEndpointRequest, InvokeEndpointResponse>()
                            .withOperationName("InvokeEndpoint")
                            .withProtocolMetadata(protocolMetadata)
                            .withResponseHandler(successHandler)
                            .withErrorResponseHandler(failureHandler)
                            .withRequestConfiguration(requestConfiguration)
                            .withInput(invokeEndpointRequest)
                            .withMetricCollector(collector)
                            .withMarshaller(new InvokeEndpointRequestMarshaller(protocolFactory));

            return clientHandler.execute(executionParams);
        } finally {
            // Publish on both success and failure; each publisher receives its own collect() result.
            for (MetricPublisher publisher : publishers) {
                publisher.publish(collector.collect());
            }
        }
    }

    /**
     * <p>
     * After you deploy a model into production using Amazon SageMaker hosting services, your client applications use
     * this API to get inferences from the model hosted at the specified endpoint in an asynchronous manner.
     * </p>
     * <p>
     * Inference requests sent to this API are enqueued for asynchronous processing. The processing of the inference
     * request may or may not complete before you receive a response from this API. The response from this API will not
     * contain the result of the inference request but contain information about where you can locate it.
     * </p>
     * <p>
     * Amazon SageMaker strips all POST headers except those supported by the API. Amazon SageMaker might add additional
     * headers. You should not rely on the behavior of headers outside those enumerated in the request syntax.
     * </p>
     * <p>
     * Calls to <code>InvokeEndpointAsync</code> are authenticated by using Amazon Web Services Signature Version 4. For
     * information, see <a
     * href="https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html">Authenticating
     * Requests (Amazon Web Services Signature Version 4)</a> in the <i>Amazon S3 API Reference</i>.
     * </p>
     *
     * @param invokeEndpointAsyncRequest
     * @return Result of the InvokeEndpointAsync operation returned by the service.
     * @throws InternalFailureException
     *         An internal failure occurred.
     * @throws ServiceUnavailableException
     *         The service is unavailable. Try your call again.
     * @throws ValidationErrorException
     *         Inspect your request and try again.
     * @throws SdkException
     *         Base class for all exceptions that can be thrown by the SDK (both service and client). Can be used for
     *         catch all scenarios.
     * @throws SdkClientException
     *         If any client side error occurs such as an IO related failure, failure to get credentials, etc.
     * @throws SageMakerRuntimeException
     *         Base class for all service exceptions. Unknown exceptions will be thrown as an instance of this type.
     * @sample SageMakerRuntimeClient.InvokeEndpointAsync
     * @see <a href="https://docs.aws.amazon.com/goto/WebAPI/runtime.sagemaker-2017-05-13/InvokeEndpointAsync"
     *      target="_top">AWS API Documentation</a>
     */
    @Override
    public InvokeEndpointAsyncResponse invokeEndpointAsync(InvokeEndpointAsyncRequest invokeEndpointAsyncRequest)
            throws InternalFailureException, ServiceUnavailableException, ValidationErrorException, AwsServiceException,
            SdkClientException, SageMakerRuntimeException {
        // Operation metadata: non-streaming success response, payload flagged as JSON.
        JsonOperationMetadata metadata = JsonOperationMetadata.builder()
                .hasStreamingSuccessResponse(false)
                .isPayloadJson(true)
                .build();

        HttpResponseHandler<InvokeEndpointAsyncResponse> successHandler = protocolFactory.createResponseHandler(metadata,
                InvokeEndpointAsyncResponse::builder);
        HttpResponseHandler<AwsServiceException> failureHandler = createErrorResponseHandler(protocolFactory, metadata);

        // Request-level plugins (if any) are applied on top of the client-level configuration.
        SdkClientConfiguration requestConfiguration = updateSdkClientConfiguration(invokeEndpointAsyncRequest,
                this.clientConfiguration);

        RequestOverrideConfiguration requestOverride = invokeEndpointAsyncRequest.overrideConfiguration().orElse(null);
        List<MetricPublisher> publishers = resolveMetricPublishers(requestConfiguration, requestOverride);

        // Only create a real collector when at least one publisher will receive the metrics.
        MetricCollector collector;
        if (publishers.isEmpty()) {
            collector = NoOpMetricCollector.create();
        } else {
            collector = MetricCollector.create("ApiCall");
        }

        try {
            collector.reportMetric(CoreMetric.SERVICE_ID, "SageMaker Runtime");
            collector.reportMetric(CoreMetric.OPERATION_NAME, "InvokeEndpointAsync");

            ClientExecutionParams<InvokeEndpointAsyncRequest, InvokeEndpointAsyncResponse> executionParams =
                    new ClientExecutionParams<InvokeEndpointAsyncRequest, InvokeEndpointAsyncResponse>()
                            .withOperationName("InvokeEndpointAsync")
                            .withProtocolMetadata(protocolMetadata)
                            .withResponseHandler(successHandler)
                            .withErrorResponseHandler(failureHandler)
                            .withRequestConfiguration(requestConfiguration)
                            .withInput(invokeEndpointAsyncRequest)
                            .withMetricCollector(collector)
                            .withMarshaller(new InvokeEndpointAsyncRequestMarshaller(protocolFactory));

            return clientHandler.execute(executionParams);
        } finally {
            // Publish on both success and failure; each publisher receives its own collect() result.
            for (MetricPublisher publisher : publishers) {
                publisher.publish(collector.collect());
            }
        }
    }

    /**
     * Returns the service name constant inherited through {@link SageMakerRuntimeClient}.
     *
     * @return the value of {@code SERVICE_NAME}
     */
    @Override
    public final String serviceName() {
        return SageMakerRuntimeClient.SERVICE_NAME;
    }

    /**
     * Selects the metric publishers for a single call.
     * <p>
     * Precedence: a non-empty publisher list from the request override configuration wins; otherwise the client-level
     * {@code METRIC_PUBLISHERS} option is used; if that is also {@code null}, an empty list is returned.
     *
     * @param clientConfiguration
     *        configuration queried for the client-level publishers
     * @param requestOverrideConfiguration
     *        request-level override configuration, may be {@code null}
     * @return the selected publishers, never {@code null}
     */
    private static List<MetricPublisher> resolveMetricPublishers(SdkClientConfiguration clientConfiguration,
            RequestOverrideConfiguration requestOverrideConfiguration) {
        if (requestOverrideConfiguration != null) {
            List<MetricPublisher> requestLevel = requestOverrideConfiguration.metricPublishers();
            if (requestLevel != null && !requestLevel.isEmpty()) {
                return requestLevel;
            }
        }

        List<MetricPublisher> clientLevel = clientConfiguration.option(SdkClientOption.METRIC_PUBLISHERS);
        if (clientLevel != null) {
            return clientLevel;
        }
        return Collections.emptyList();
    }

    /**
     * Asks the given protocol factory for an error response handler matching the operation metadata.
     *
     * @param protocolFactory
     *        factory that creates the handler
     * @param operationMetadata
     *        metadata of the operation being invoked
     * @return the handler produced by the factory
     */
    private HttpResponseHandler<AwsServiceException> createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory,
            JsonOperationMetadata operationMetadata) {
        HttpResponseHandler<AwsServiceException> errorHandler = protocolFactory
                .createErrorResponseHandler(operationMetadata);
        return errorHandler;
    }

    /**
     * Produces the configuration to use for one request by applying any request-level plugins.
     * <p>
     * When the request carries no plugins the supplied configuration is returned as-is. Otherwise a service
     * configuration builder is created from it, seeded with this client's override configuration, handed to each
     * plugin in list order, and then built.
     *
     * @param request
     *        request whose override configuration may list plugins
     * @param clientConfiguration
     *        configuration to start from
     * @return the supplied configuration, or a new one reflecting the plugins' changes
     */
    private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, SdkClientConfiguration clientConfiguration) {
        List<SdkPlugin> requestPlugins = request.overrideConfiguration()
                .map(RequestOverrideConfiguration::plugins)
                .orElse(Collections.emptyList());
        if (requestPlugins.isEmpty()) {
            return clientConfiguration;
        }

        SdkClientConfiguration.Builder baseBuilder = clientConfiguration.toBuilder();
        SageMakerRuntimeServiceClientConfigurationBuilder.BuilderInternal pluginTarget =
                SageMakerRuntimeServiceClientConfigurationBuilder.builder(baseBuilder);
        pluginTarget.overrideConfiguration(serviceClientConfiguration.overrideConfiguration());

        for (SdkPlugin requestPlugin : requestPlugins) {
            requestPlugin.configureClient(pluginTarget);
        }
        return pluginTarget.buildSdkClientConfiguration();
    }

    /**
     * Configures a protocol factory builder for this service: client configuration, default exception supplier,
     * REST-JSON protocol at version "1.1", and the modeled exceptions keyed by error code.
     *
     * @param builder
     *        builder to configure
     * @return the value returned by the last chained builder call
     */
    private <T extends BaseAwsJsonProtocolFactory.Builder<T>> T init(T builder) {
        // Stream-related errors are registered without an HTTP status code.
        ExceptionMetadata modelStreamError = ExceptionMetadata.builder()
                .errorCode("ModelStreamError")
                .exceptionBuilderSupplier(ModelStreamErrorException::builder)
                .build();
        ExceptionMetadata modelError = ExceptionMetadata.builder()
                .errorCode("ModelError")
                .exceptionBuilderSupplier(ModelErrorException::builder)
                .httpStatusCode(424)
                .build();
        ExceptionMetadata internalStreamFailure = ExceptionMetadata.builder()
                .errorCode("InternalStreamFailure")
                .exceptionBuilderSupplier(InternalStreamFailureException::builder)
                .build();
        ExceptionMetadata validationError = ExceptionMetadata.builder()
                .errorCode("ValidationError")
                .exceptionBuilderSupplier(ValidationErrorException::builder)
                .httpStatusCode(400)
                .build();
        ExceptionMetadata serviceUnavailable = ExceptionMetadata.builder()
                .errorCode("ServiceUnavailable")
                .exceptionBuilderSupplier(ServiceUnavailableException::builder)
                .httpStatusCode(503)
                .build();
        ExceptionMetadata internalFailure = ExceptionMetadata.builder()
                .errorCode("InternalFailure")
                .exceptionBuilderSupplier(InternalFailureException::builder)
                .httpStatusCode(500)
                .build();
        ExceptionMetadata internalDependency = ExceptionMetadata.builder()
                .errorCode("InternalDependencyException")
                .exceptionBuilderSupplier(InternalDependencyException::builder)
                .httpStatusCode(530)
                .build();
        ExceptionMetadata modelNotReady = ExceptionMetadata.builder()
                .errorCode("ModelNotReadyException")
                .exceptionBuilderSupplier(ModelNotReadyException::builder)
                .httpStatusCode(429)
                .build();

        return builder
                .clientConfiguration(clientConfiguration)
                .defaultServiceExceptionSupplier(SageMakerRuntimeException::builder)
                .protocol(AwsJsonProtocol.REST_JSON)
                .protocolVersion("1.1")
                .registerModeledException(modelStreamError)
                .registerModeledException(modelError)
                .registerModeledException(internalStreamFailure)
                .registerModeledException(validationError)
                .registerModeledException(serviceUnavailable)
                .registerModeledException(internalFailure)
                .registerModeledException(internalDependency)
                .registerModeledException(modelNotReady);
    }

    /**
     * Returns the service client configuration this client was constructed with.
     *
     * @return the configuration instance passed to the constructor
     */
    @Override
    public final SageMakerRuntimeServiceClientConfiguration serviceClientConfiguration() {
        return serviceClientConfiguration;
    }

    /**
     * Closes this client by closing the underlying client handler.
     */
    @Override
    public void close() {
        this.clientHandler.close();
    }
}
