/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.guardrail;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.guardrail.Guardrail;
import dev.langchain4j.guardrail.GuardrailRequestParams;
import dev.langchain4j.guardrail.InputGuardrail;
import dev.langchain4j.guardrail.InputGuardrailException;
import dev.langchain4j.guardrail.InputGuardrailExecutor;
import dev.langchain4j.guardrail.InputGuardrailRequest;
import dev.langchain4j.guardrail.InputGuardrailResult;
import dev.langchain4j.guardrail.config.GuardrailsConfig;
import dev.langchain4j.guardrail.config.InputGuardrailsConfig;
import dev.langchain4j.invocation.InvocationContext;
import dev.langchain4j.test.guardrail.GuardrailAssertions;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.aggregator.AggregateWith;
import org.junit.jupiter.params.aggregator.ArgumentsAccessor;
import org.junit.jupiter.params.aggregator.ArgumentsAggregationException;
import org.junit.jupiter.params.aggregator.ArgumentsAggregator;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;

/**
 * Tests for {@link InputGuardrailExecutor}.
 *
 * <p>Each case checks which guardrails in a chain are invoked. It also checks the result or
 * exception the executor produces for chains of successful, failing and fatal guardrails.
 */
class InputGuardrailExecutorTests {
    /** Invocation context shared by every request built through {@link #from(UserMessage)}. */
    private static final InvocationContext DEFAULT_INVOCATION_CONTEXT = InvocationContext.builder()
            .interfaceName("SomeInterface")
            .methodName("someMethod")
            .methodArgument("one")
            .methodArgument("two")
            .chatMemoryId("one")
            .build();

    /**
     * Runs chains made only of successful guardrails.
     *
     * <p>Checks that the result is successful and that exactly the first
     * {@code howManyShouldExecute} guardrails were invoked. The trailing guardrails of each
     * {@code @MethodSource} row are collected by {@link InputGuardrailAggregator}. Without the
     * {@code @AggregateWith} annotation, JUnit would try to convert the single argument at this
     * index into the varargs array.
     */
    @ParameterizedTest(name = "{0}")
    @MethodSource("successGuardrails")
    void allSuccessfulGuardrails(
            String testDesc,
            int howManyShouldExecute,
            @AggregateWith(InputGuardrailAggregator.class) InputGuardrail... guardrails) {
        InputGuardrail[] spiedGuardrails = spyAll(guardrails);
        InputGuardrailRequest request = from(UserMessage.from("test"));
        InputGuardrailExecutor executor =
                InputGuardrailExecutor.builder().guardrails(spiedGuardrails).build();

        InputGuardrailResult result = executor.execute(request);

        GuardrailAssertions.assertThat(result).isSuccessful();
        assertOnlyLeadingGuardrailsExecuted(spiedGuardrails, howManyShouldExecute, request);
    }

    /** An executor built without any guardrails reports success. */
    @Test
    void noGuardrails() {
        InputGuardrailRequest request = from(UserMessage.from("test"));
        InputGuardrailExecutor executor = InputGuardrailExecutor.builder().build();

        InputGuardrailResult result = executor.execute(request);

        GuardrailAssertions.assertThat(result).isSuccessful();
    }

    /**
     * Runs chains containing failing and/or fatal guardrails.
     *
     * <p>Checks the following:
     * <ul>
     *   <li>the executor throws an {@link InputGuardrailException} naming a failing guardrail;
     *   <li>exactly the first {@code howManyShouldExecute} guardrails were invoked;
     *   <li>{@code howManyFailures} failing guardrails were expected to run.
     * </ul>
     */
    @ParameterizedTest(name = "{0}")
    @MethodSource("failedFatalGuardrails")
    void failedFatal(
            String testDesc,
            int howManyShouldExecute,
            int howManyFailures,
            @AggregateWith(InputGuardrailAggregator.class) InputGuardrail... guardrails) {
        InputGuardrail[] spiedGuardrails = spyAll(guardrails);
        InputGuardrailRequest request = from(UserMessage.from("test"));
        InputGuardrailExecutor executor = InputGuardrailExecutor.builder()
                .guardrails(spiedGuardrails)
                .config(InputGuardrailsConfig.builder().build())
                .build();
        // Quote the class name so its '.' separators match literally rather than "any character".
        String expectedMessage = "The guardrail " + Pattern.quote(getClass().getName())
                + "\\$.+Guardrail failed with this message: failure \\d";

        Assertions.assertThatExceptionOfType(InputGuardrailException.class)
                .isThrownBy(() -> executor.execute(request))
                .withMessageMatching(expectedMessage);

        assertOnlyLeadingGuardrailsExecuted(spiedGuardrails, howManyShouldExecute, request);

        // Fatal guardrails extend FailureInputGuardrail, so they are counted here as well.
        long numFailedGuardrails = Stream.of(spiedGuardrails)
                .filter(FailureInputGuardrail.class::isInstance)
                .map(FailureInputGuardrail.class::cast)
                .filter(guardrail -> guardrail.shouldBeExecuted)
                .count();
        Assertions.assertThat(numFailedGuardrails).isEqualTo(howManyFailures);
    }

    /** Wraps each guardrail in a Mockito spy so that its invocations can be verified. */
    private static InputGuardrail[] spyAll(InputGuardrail[] guardrails) {
        return Stream.of(guardrails).map(Mockito::spy).toArray(InputGuardrail[]::new);
    }

    /**
     * Checks the execution split of a chain of spied guardrails.
     *
     * <p>The first {@code howManyShouldExecute} guardrails must be flagged as expected to run and
     * must have had {@code validate(request)} invoked. All remaining guardrails must be flagged as
     * expected to be skipped and must never have been invoked.
     */
    private static void assertOnlyLeadingGuardrailsExecuted(
            InputGuardrail[] spiedGuardrails, int howManyShouldExecute, InputGuardrailRequest request) {
        // Reject inconsistent test data instead of silently checking a meaningless split.
        Assertions.assertThat(howManyShouldExecute).isBetween(0, spiedGuardrails.length);
        for (int i = 0; i < spiedGuardrails.length; i++) {
            InputGuardrail guardrail = spiedGuardrails[i];
            if (i < howManyShouldExecute) {
                Assertions.assertThat(expectedToExecute(guardrail)).isTrue();
                Mockito.verify(guardrail).validate(request);
            } else {
                Assertions.assertThat(expectedToExecute(guardrail)).isFalse();
                Mockito.verify(guardrail, Mockito.never()).validate(request);
            }
        }
    }

    /** Returns the {@code shouldBeExecuted} expectation recorded on one of the test guardrails. */
    private static boolean expectedToExecute(InputGuardrail guardrail) {
        if (guardrail instanceof SuccessInputGuardrail success) {
            return success.shouldBeExecuted;
        }
        return ((FailureInputGuardrail) guardrail).shouldBeExecuted;
    }

    /**
     * Provides rows for {@link #allSuccessfulGuardrails}.
     *
     * <p>Each row contains a description, the number of guardrails that should execute, and then
     * the guardrails themselves.
     */
    static Stream<Arguments> successGuardrails() {
        return Stream.of(
                Arguments.of("No guardrails", 0),
                Arguments.of("One successful guardrail", 1, new SuccessInputGuardrail()),
                Arguments.of(
                        "Two successful guardrails",
                        2,
                        new SuccessInputGuardrail(),
                        new SuccessInputGuardrail()),
                Arguments.of(
                        "Three successful guardrails",
                        3,
                        new SuccessInputGuardrail(),
                        new SuccessInputGuardrail(),
                        new SuccessInputGuardrail()));
    }

    /**
     * Provides rows for {@link #failedFatal}.
     *
     * <p>Each row contains a description, the number of guardrails that should execute, the number
     * of failing guardrails expected to run, and then the guardrails themselves.
     */
    static Stream<Arguments> failedFatalGuardrails() {
        return Stream.of(
                Arguments.of(
                        "One successful one fatal guardrail",
                        2,
                        1,
                        new SuccessInputGuardrail(),
                        new FatalInputGuardrail(1)),
                Arguments.of(
                        "One fatal one successful guardrail",
                        1,
                        1,
                        new FatalInputGuardrail(1),
                        new SuccessInputGuardrail(false)),
                Arguments.of(
                        "One successful one fatal one successful guardrails",
                        2,
                        1,
                        new SuccessInputGuardrail(),
                        new FatalInputGuardrail(1),
                        new SuccessInputGuardrail(false)),
                Arguments.of(
                        "One successful one fatal one failed guardrails",
                        2,
                        1,
                        new SuccessInputGuardrail(),
                        new FatalInputGuardrail(1),
                        new FailureInputGuardrail(2).shouldNotBeExecuted()),
                Arguments.of(
                        "One failure one successful guardrail",
                        2,
                        1,
                        new FailureInputGuardrail(1),
                        new SuccessInputGuardrail()),
                Arguments.of(
                        "One successful one failure one successful guardrails",
                        3,
                        1,
                        new SuccessInputGuardrail(),
                        new FailureInputGuardrail(1),
                        new SuccessInputGuardrail()),
                // NOTE(review): same arguments as "One successful one fatal one failed guardrails".
                Arguments.of(
                        "One successful one fatal one failure guardrails",
                        2,
                        1,
                        new SuccessInputGuardrail(),
                        new FatalInputGuardrail(1),
                        new FailureInputGuardrail(2).shouldNotBeExecuted()),
                Arguments.of(
                        "Two failure guardrails",
                        2,
                        2,
                        new FailureInputGuardrail(1),
                        new FailureInputGuardrail(2)),
                Arguments.of(
                        "One successful one failure one fatal one failure guardrails",
                        3,
                        2,
                        new SuccessInputGuardrail(),
                        new FailureInputGuardrail(2),
                        new FatalInputGuardrail(1),
                        new FailureInputGuardrail(3).shouldNotBeExecuted()));
    }

    /**
     * Builds an {@link InputGuardrailRequest} for {@code userMessage}.
     *
     * <p>The request has no chat memory, no augmentation result, an empty template, no variables,
     * and {@link #DEFAULT_INVOCATION_CONTEXT}.
     */
    public static InputGuardrailRequest from(UserMessage userMessage) {
        GuardrailRequestParams newCommonParams = GuardrailRequestParams.builder()
                .chatMemory(null)
                .augmentationResult(null)
                .userMessageTemplate("")
                .variables(Map.of())
                .invocationContext(DEFAULT_INVOCATION_CONTEXT)
                .build();
        return InputGuardrailRequest.builder()
                .userMessage(userMessage)
                .commonParams(newCommonParams)
                .build();
    }

    /**
     * Test guardrail that always returns a failure with message {@code "failure <n>"}.
     *
     * <p>The failure does not stop the chain.
     */
    private static class FailureInputGuardrail implements InputGuardrail {
        protected final String failureMessage;
        /** Whether the test case expects the executor to invoke this guardrail. */
        private boolean shouldBeExecuted = true;

        private FailureInputGuardrail(int failureNumber) {
            this("failure " + failureNumber);
        }

        private FailureInputGuardrail(String failureMessage) {
            this.failureMessage = failureMessage;
        }

        /** Marks this guardrail as expected to be skipped by the executor and returns {@code this}. */
        FailureInputGuardrail shouldNotBeExecuted() {
            this.shouldBeExecuted = false;
            return this;
        }

        @Override
        public InputGuardrailResult validate(UserMessage userMessage) {
            return failure(failureMessage);
        }
    }

    /** Test guardrail that always succeeds. */
    private static class SuccessInputGuardrail implements InputGuardrail {
        /** Whether the test case expects the executor to invoke this guardrail. */
        private boolean shouldBeExecuted = true;

        SuccessInputGuardrail(boolean shouldBeExecuted) {
            this.shouldBeExecuted = shouldBeExecuted;
        }

        SuccessInputGuardrail() {
            this(true);
        }

        @Override
        public InputGuardrailResult validate(UserMessage userMessage) {
            return InputGuardrailResult.success();
        }
    }

    /** Test guardrail that always returns a fatal result, which stops the chain. */
    private static class FatalInputGuardrail extends FailureInputGuardrail {
        private FatalInputGuardrail(int failureNumber) {
            super(failureNumber);
        }

        @Override
        public InputGuardrailResult validate(UserMessage userMessage) {
            return fatal(failureMessage);
        }
    }

    /**
     * Collects every argument from this parameter's index onward into an {@code InputGuardrail[]}.
     *
     * <p>This lets a {@code @MethodSource} row supply any number of trailing guardrails.
     */
    static class InputGuardrailAggregator implements ArgumentsAggregator {
        @Override
        public Object aggregateArguments(ArgumentsAccessor accessor, ParameterContext context)
                throws ArgumentsAggregationException {
            return accessor.toList().stream()
                    .skip(context.getIndex())
                    .map(InputGuardrail.class::cast)
                    .toArray(InputGuardrail[]::new);
        }
    }
}

