// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.http.filter;

import com.yahoo.component.AbstractComponent;
import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.chain.ChainedComponent;
import com.yahoo.component.chain.ChainsConfigurer;
import com.yahoo.component.chain.dependencies.Dependencies;
import com.yahoo.component.chain.model.Chainable;
import com.yahoo.component.chain.model.ChainsModel;
import com.yahoo.component.chain.model.ChainsModelBuilder;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.container.core.ChainsConfig;
import com.yahoo.jdisc.http.filter.RequestFilter;
import com.yahoo.jdisc.http.filter.ResponseFilter;
import com.yahoo.jdisc.http.filter.SecurityRequestFilter;
import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain;
import com.yahoo.jdisc.http.filter.SecurityResponseFilter;
import com.yahoo.jdisc.http.filter.SecurityResponseFilterChain;
import com.yahoo.jdisc.http.filter.chain.RequestFilterChain;
import com.yahoo.jdisc.http.filter.chain.ResponseFilterChain;
import com.yahoo.processing.execution.chain.ChainRegistry;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;

import static java.util.List.of;
import static java.util.stream.Collectors.toSet;

/**
 * Creates JDisc request/response filter chains.
 *
 * @author Tony Vaagenes
 * @author bjorncs
 */
public class FilterChainRepository extends AbstractComponent {

    private static final Logger log = Logger.getLogger(FilterChainRepository.class.getName());

    private final ComponentRegistry<Object> filterAndChains;

    public FilterChainRepository(ChainsConfig chainsConfig,
                                 ComponentRegistry<RequestFilter> requestFilters,
                                 ComponentRegistry<ResponseFilter> responseFilters,
                                 ComponentRegistry<SecurityRequestFilter> securityRequestFilters,
                                 ComponentRegistry<SecurityResponseFilter> securityResponseFilters) {
        ComponentRegistry<Object> filterAndChains = new ComponentRegistry<>();
        addAllFilters(filterAndChains, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters);
        addAllChains(filterAndChains, chainsConfig, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters);
        filterAndChains.freeze();
        this.filterAndChains = filterAndChains;
    }

    public Object getFilter(ComponentSpecification componentSpecification) {
        return filterAndChains.getComponent(componentSpecification);
    }

    private static void addAllFilters(ComponentRegistry<Object> destination,
                                      ComponentRegistry<?>... registries) {
        for (ComponentRegistry<?> registry : registries) {
            registry.allComponentsById()
                    .forEach((id, filter) -> destination.register(id, wrapIfSecurityFilter(filter)));
        }
    }

    @SafeVarargs
    private static void addAllChains(ComponentRegistry<Object> destination,
                                     ChainsConfig chainsConfig,
                                     ComponentRegistry<? extends Chainable>... filters) {
        ChainRegistry<FilterWrapper> chainRegistry = buildChainRegistry(chainsConfig, filters);
        chainRegistry.allComponents()
                .forEach(chain -> destination.register(chain.getId(), toJDiscChain(chain)));
    }

    @SafeVarargs
    private static ChainRegistry<FilterWrapper> buildChainRegistry(ChainsConfig chainsConfig,
                                                                   ComponentRegistry<? extends Chainable>... filters) {
        ChainRegistry<FilterWrapper> chainRegistry = new ChainRegistry<>();
        ChainsModel chainsModel = ChainsModelBuilder.buildFromConfig(chainsConfig);
        ChainsConfigurer.prepareChainRegistry(chainRegistry, chainsModel, allFiltersWrapped(filters));
        removeEmptyChains(chainRegistry);
        chainRegistry.freeze();
        return chainRegistry;
    }

    private static void removeEmptyChains(ChainRegistry<FilterWrapper> chainRegistry) {
        chainRegistry.allComponents().stream()
                .filter(chain -> chain.components().isEmpty())
                .map(Chain::getId)
                .peek(id -> log.warning("Removing empty filter chain: " + id))
                .forEach(chainRegistry::unregister);
    }

    @SuppressWarnings("unchecked")
    private static Object toJDiscChain(Chain<FilterWrapper> chain) {
        if (chain.components().isEmpty())
            throw new IllegalArgumentException("Empty filter chain: " + chain.getId());
        checkFilterTypesCompatible(chain);
        List<?> jdiscFilters = chain.components().stream()
                        .map(filterWrapper -> filterWrapper.filter)
                        .toList();
        List<?> wrappedFilters = wrapSecurityFilters(jdiscFilters);
        Object head = wrappedFilters.get(0);
        if (wrappedFilters.size() == 1) return head;
        else if (head instanceof RequestFilter)
            return RequestFilterChain.newInstance((List<RequestFilter>) wrappedFilters);
        else if (head instanceof ResponseFilter)
            return ResponseFilterChain.newInstance((List<ResponseFilter>) wrappedFilters);
        throw new IllegalStateException();
    }

    private static List<?> wrapSecurityFilters(List<?> filters) {
        List<Object> aggregatedSecurityFilters = new ArrayList<>();
        List<Object> wrappedFilters = new ArrayList<>();
        for (Object filter : filters) {
            if (isSecurityFilter(filter)) {
                aggregatedSecurityFilters.add(filter);
            } else {
                if (!aggregatedSecurityFilters.isEmpty()) {
                    wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters));
                    aggregatedSecurityFilters.clear();
                }
                wrappedFilters.add(filter);
            }
        }
        if (!aggregatedSecurityFilters.isEmpty()) {
            wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters));
        }
        return wrappedFilters;
    }

    private static void checkFilterTypesCompatible(Chain<FilterWrapper> chain) {
        Set<ComponentId> requestFilters = chain.components().stream()
                .filter(filter -> filter instanceof RequestFilter || filter instanceof SecurityRequestFilter)
                .map(FilterWrapper::getId)
                .collect(toSet());
        Set<ComponentId> responseFilters = chain.components().stream()
                .filter(filter -> filter instanceof ResponseFilter || filter instanceof SecurityResponseFilter)
                .map(FilterWrapper::getId)
                .collect(toSet());
        if (!requestFilters.isEmpty() && !responseFilters.isEmpty()) {
            throw new IllegalArgumentException(
                    String.format(
                            "Can't mix request and response filters in chain %s: request filters: %s, response filters: %s.",
                            chain.getId(), requestFilters, responseFilters));
        }
    }

    @SafeVarargs
    private static ComponentRegistry<FilterWrapper> allFiltersWrapped(ComponentRegistry<? extends Chainable>... registries) {
        ComponentRegistry<FilterWrapper> wrappedFilters = new ComponentRegistry<>();
        for (ComponentRegistry<? extends Chainable> registry : registries) {
            registry.allComponentsById()
                    .forEach((id, filter) -> wrappedFilters.register(id, new FilterWrapper(id, filter)));
        }
        wrappedFilters.freeze();
        return wrappedFilters;
    }

    private static Object wrapIfSecurityFilter(Object filter) {
        if (isSecurityFilter(filter)) return createSecurityChain(List.of(filter));
        return filter;
    }

    @SuppressWarnings("unchecked")
    private static Object createSecurityChain(List<?> filters) {
        Object head = filters.get(0);
        if (head instanceof SecurityRequestFilter)
            return SecurityRequestFilterChain.newInstance((List<SecurityRequestFilter>) filters);
        else if (head instanceof SecurityResponseFilter)
            return SecurityResponseFilterChain.newInstance((List<SecurityResponseFilter>) filters);
        throw new IllegalArgumentException("Unexpected class " + head.getClass());
    }

    private static boolean isSecurityFilter(Object filter) {
        return filter instanceof SecurityRequestFilter || filter instanceof SecurityResponseFilter;
    }

    private static class FilterWrapper extends ChainedComponent {
        public final Chainable filter;
        public final Class<? extends Chainable> filterType;

        public FilterWrapper(ComponentId id, Chainable filter) {
            super(id);
            this.filter = filter;
            this.filterType = getFilterType(filter);
        }

        @Override
        public Dependencies getAnnotatedDependencies() {
            return filter == null ? super.getAnnotatedDependencies() : filter.getAnnotatedDependencies();
        }

        private static Class<? extends Chainable> getFilterType(Object filter) {
            if (filter instanceof RequestFilter)
                return RequestFilter.class;
            else if (filter instanceof ResponseFilter)
                return ResponseFilter.class;
            else if (filter instanceof SecurityRequestFilter)
                return SecurityRequestFilter.class;
            else if (filter instanceof SecurityResponseFilter)
                return SecurityResponseFilter.class;
            throw new IllegalArgumentException("Unsupported filter type: " + filter.getClass().getName());
        }
    }

}
