/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.search.ranking;

import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import com.yahoo.search.ranking.DummyEvaluator;
import com.yahoo.search.ranking.Evaluator;
import com.yahoo.search.ranking.FunEvalSpec;
import com.yahoo.search.ranking.LinearNormalizer;
import com.yahoo.search.ranking.MatchFeatureInput;
import com.yahoo.search.ranking.Normalizer;
import com.yahoo.search.ranking.NormalizerSetup;
import com.yahoo.search.ranking.RankProfilesEvaluator;
import com.yahoo.search.ranking.ReciprocalRankNormalizer;
import com.yahoo.search.ranking.SimpleEvaluator;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;

/**
 * Global-phase setup derived from the configuration of a single rank profile.
 *
 * <p>Holds the global-phase function (with the origin of each of its inputs), the configured
 * rerank count and rank-score drop limit, the match features configured as hidden, the
 * normalizers referenced by the global-phase function, and default values for the query
 * features read by the function or by normalizer inputs.
 */
class GlobalPhaseSetup {
    /** The global-phase function together with where each of its inputs comes from. */
    final FunEvalSpec globalPhaseEvalSpec;
    /** From {@code vespa.globalphase.rerankcount}; 100 when unset or negative. */
    final int rerankCount;
    /** From {@code vespa.globalphase.rankscoredroplimit}; {@code -Double.MAX_VALUE} when unset. */
    final double rankScoreDropLimit;
    /** Values of all {@code vespa.hidden.matchfeature} properties. */
    final Collection<String> matchFeaturesToHide;
    /** Setups for the normalizers referenced by the global-phase function. */
    final List<NormalizerSetup> normalizers;
    /** Default tensor per query feature, keyed by the wrapped name {@code query(name)}. */
    final Map<String, Tensor> defaultValues;

    GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec, int rerankCount, double rankScoreDropLimit, Collection<String> matchFeaturesToHide, List<NormalizerSetup> normalizers, Map<String, Tensor> defaultValues) {
        this.globalPhaseEvalSpec = globalPhaseEvalSpec;
        this.rerankCount = rerankCount;
        this.rankScoreDropLimit = rankScoreDropLimit;
        this.matchFeaturesToHide = matchFeaturesToHide;
        this.normalizers = normalizers;
        this.defaultValues = defaultValues;
    }

    /**
     * Builds default values for every query feature read by the global-phase function or by any
     * normalizer input.
     *
     * <p>Values come from the rank profile properties {@code query(name)} (an explicit value) and
     * {@code vespa.type.query.name} (a tensor type). See {@link DefaultQueryFeatureExtractor#extract()}
     * for how the two are combined.
     *
     * @param rp the rank profile whose properties are scanned
     * @param fromQuery unwrapped query feature names (without {@code query(...)}) read by the
     *     global-phase function
     * @param normalizers normalizer setups whose inputs may read further query features
     * @return map from the wrapped name {@code query(name)} to its default tensor
     * @throws IllegalStateException if two extractors listen for the same property key
     */
    private static Map<String, Tensor> extraDefaultQueryFeatureValues(RankProfilesConfig.Rankprofile rp, List<String> fromQuery, List<NormalizerSetup> normalizers) {
        Map<String, DefaultQueryFeatureExtractor> extractors = new HashMap<>();
        for (String name : fromQuery) {
            extractors.put(name, new DefaultQueryFeatureExtractor(name));
        }
        for (NormalizerSetup normalizerSetup : normalizers) {
            for (String name : normalizerSetup.inputEvalSpec().fromQuery()) {
                extractors.put(name, new DefaultQueryFeatureExtractor(name));
            }
        }
        // Index each extractor by every property key it wants to receive.
        Map<String, DefaultQueryFeatureExtractor> targets = new HashMap<>();
        for (DefaultQueryFeatureExtractor extractor : extractors.values()) {
            for (String key : extractor.lookingFor()) {
                if (targets.put(key, extractor) != null) {
                    throw new IllegalStateException("Multiple targets for key: " + key);
                }
            }
        }
        for (RankProfilesConfig.Rankprofile.Fef.Property prop : rp.fef().property()) {
            DefaultQueryFeatureExtractor extractor = targets.get(prop.name());
            if (extractor != null) {
                extractor.accept(prop.name(), prop.value());
            }
        }
        Map<String, Tensor> result = new HashMap<>();
        for (DefaultQueryFeatureExtractor extractor : extractors.values()) {
            result.put(extractor.qfName, extractor.extract());
        }
        return result;
    }

    /**
     * Builds the global-phase setup for a rank profile.
     *
     * <p>Reads these rank profile properties:
     * <ul>
     *   <li>{@code vespa.globalphase.rerankcount}: rerank count (100 if unset or negative)</li>
     *   <li>{@code vespa.globalphase.rankscoredroplimit}: drop limit (default {@code -Double.MAX_VALUE})</li>
     *   <li>{@code vespa.rank.globalphase}: its presence enables the global phase</li>
     *   <li>{@code vespa.hidden.matchfeature}: match features to hide</li>
     *   <li>{@code vespa.match.feature}: available match features</li>
     *   <li>{@code vespa.feature.rename}: consecutive (from, to) pairs renaming match features</li>
     * </ul>
     *
     * @param rp the rank profile configuration
     * @param modelEvaluator source of the model holding the profile's ranking functions
     * @return the setup, or {@code null} if the profile has no {@code vespa.rank.globalphase} property
     * @throws IllegalArgumentException if an input of the global-phase function or of a normalizer
     *     cannot be resolved
     */
    static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) {
        Model model = modelEvaluator.modelForRankProfile(rp.name());
        Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>();
        for (RankProfilesConfig.Rankprofile.Normalizer n : rp.normalizer()) {
            availableNormalizers.put(n.name(), n);
        }
        boolean hasGlobalPhase = false;
        int rerankCount = -1;
        double rankScoreDropLimit = -Double.MAX_VALUE;
        Set<String> namesToHide = new HashSet<>();
        Set<String> matchFeatures = new HashSet<>();
        Map<String, String> renameFeatures = new HashMap<>();
        String renameFrom = null;
        for (RankProfilesConfig.Rankprofile.Fef.Property prop : rp.fef().property()) {
            switch (prop.name()) {
                case "vespa.globalphase.rerankcount" -> rerankCount = Integer.parseInt(prop.value());
                case "vespa.globalphase.rankscoredroplimit" -> rankScoreDropLimit = Double.parseDouble(prop.value());
                case "vespa.rank.globalphase" -> hasGlobalPhase = true;
                case "vespa.hidden.matchfeature" -> namesToHide.add(prop.value());
                case "vespa.match.feature" -> matchFeatures.add(prop.value());
                case "vespa.feature.rename" -> {
                    // Rename properties come in pairs: first the original name, then the new one.
                    if (renameFrom == null) {
                        renameFrom = prop.value();
                    } else {
                        renameFeatures.put(renameFrom, prop.value());
                        renameFrom = null;
                    }
                }
                default -> { }
            }
        }
        if (!hasGlobalPhase) {
            return null;
        }
        // Only renames of actual match features are relevant.
        renameFeatures.entrySet().removeIf(entry -> !matchFeatures.contains(entry.getKey()));
        if (rerankCount < 0) {
            rerankCount = 100;
        }
        Supplier<FunctionEvaluator> functionEvaluatorSource = () -> model.evaluatorOf("globalphase");
        InputResolver mainResolver = new InputResolver(matchFeatures, renameFeatures, availableNormalizers.keySet());
        // One evaluator is created up front only to learn the function's argument names.
        List<String> allInputs = List.copyOf(functionEvaluatorSource.get().function().arguments());
        mainResolver.resolve(allInputs);
        List<NormalizerSetup> normalizers = makeNormalizerSetups(model, mainResolver.usedNormalizers, availableNormalizers, matchFeatures, renameFeatures, rerankCount);
        Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource);
        FunEvalSpec gfun = new FunEvalSpec(supplier, mainResolver.fromQuery, mainResolver.fromMF);
        Map<String, Tensor> defaultValues = extraDefaultQueryFeatureValues(rp, mainResolver.fromQuery, normalizers);
        return new GlobalPhaseSetup(gfun, rerankCount, rankScoreDropLimit, namesToHide, normalizers, defaultValues);
    }

    /**
     * Creates one setup per normalizer name in {@code usedNormalizers}, in that order.
     *
     * <p>If a normalizer's input is a match feature, or the new name of a renamed one, it is wrapped
     * in a {@code DummyEvaluator} over that single name. Otherwise the input is evaluated as a
     * function of the model, and that function's own arguments are resolved as inputs.
     */
    private static List<NormalizerSetup> makeNormalizerSetups(Model model, List<String> usedNormalizers, Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers, Set<String> matchFeatures, Map<String, String> renameFeatures, int rerankCount) {
        List<NormalizerSetup> normalizers = new ArrayList<>();
        for (String normalizerName : usedNormalizers) {
            RankProfilesConfig.Rankprofile.Normalizer cfg = availableNormalizers.get(normalizerName);
            String normInput = cfg.input();
            if (matchFeatures.contains(normInput) || renameFeatures.containsValue(normInput)) {
                Supplier<Evaluator> normSource = () -> new DummyEvaluator(normInput);
                normalizers.add(makeNormalizerSetup(cfg, matchFeatures, renameFeatures, normSource, List.of(normInput), rerankCount));
            } else {
                // Separate Supplier<FunctionEvaluator> typing: SimpleEvaluator.wrap adapts it to Evaluator.
                Supplier<FunctionEvaluator> normSource = () -> model.evaluatorOf(normInput);
                List<String> normInputs = List.copyOf(normSource.get().function().arguments());
                Supplier<Evaluator> normSupplier = SimpleEvaluator.wrap(normSource);
                normalizers.add(makeNormalizerSetup(cfg, matchFeatures, renameFeatures, normSupplier, normInputs, rerankCount));
            }
        }
        return normalizers;
    }

    /** Resolves the inputs of one normalizer and pairs them with its configured algorithm. */
    private static NormalizerSetup makeNormalizerSetup(RankProfilesConfig.Rankprofile.Normalizer cfg, Set<String> matchFeatures, Map<String, String> renamedFeatures, Supplier<Evaluator> evalSupplier, List<String> normInputs, int rerankCount) {
        // With no available normalizers, a normalizer input can never refer to another normalizer.
        InputResolver normResolver = new InputResolver(matchFeatures, renamedFeatures, Set.of());
        normResolver.resolve(normInputs);
        FunEvalSpec fun = new FunEvalSpec(evalSupplier, normResolver.fromQuery, normResolver.fromMF);
        return new NormalizerSetup(cfg.name(), makeNormalizerSupplier(cfg, rerankCount), fun);
    }

    /** Returns a factory for the normalizer implementation selected by {@code cfg.algo()}. */
    private static Supplier<Normalizer> makeNormalizerSupplier(RankProfilesConfig.Rankprofile.Normalizer cfg, int rerankCount) {
        return switch (cfg.algo()) {
            case LINEAR -> () -> new LinearNormalizer(rerankCount);
            case RRANK -> () -> new ReciprocalRankNormalizer(rerankCount, cfg.kparam());
        };
    }

    /**
     * Unwraps a simple query feature reference.
     *
     * @param input a feature name such as {@code query(foo)}
     * @return the argument ({@code foo}) if {@code input} parses as a simple {@code query(...)}
     *     reference, otherwise {@code null}
     */
    static String asQueryFeature(String input) {
        Optional<Reference> optRef = Reference.simple(input);
        return optRef
                .filter(ref -> ref.isSimple() && ref.name().equals("query"))
                .map(ref -> ref.simpleArgument().get())
                .orElse(null);
    }

    /**
     * Collects the default value for one query feature from rank profile properties.
     *
     * <p>Listens for the keys {@code query(baseName)} (an explicit default value) and
     * {@code vespa.type.query.baseName} (the feature's tensor type).
     */
    static class DefaultQueryFeatureExtractor {
        final String baseName;
        final String qfName;
        TensorType type = null;
        Tensor value = null;

        DefaultQueryFeatureExtractor(String unwrappedQueryFeature) {
            this.baseName = unwrappedQueryFeature;
            this.qfName = "query(" + baseName + ")";
        }

        /** Property keys this extractor wants to receive through {@link #accept}. */
        List<String> lookingFor() {
            return List.of(qfName, "vespa.type.query." + baseName);
        }

        /**
         * Records a property value. Any key other than {@code qfName} is treated as the type key,
         * since callers only route keys returned by {@link #lookingFor()}.
         */
        void accept(String key, String propValue) {
            if (key.equals(qfName)) {
                // NOTE(review): parsed without the declared type, even if one is known; confirm intended.
                this.value = Tensor.from(propValue);
            } else {
                this.type = TensorType.fromSpec(propValue);
            }
        }

        /**
         * Returns the explicit value if one was seen. Otherwise returns a freshly built tensor of the
         * declared type with no cells added, or the scalar {@code 0.0} if no type was declared either.
         */
        Tensor extract() {
            if (value != null) {
                return value;
            }
            if (type != null) {
                return Tensor.Builder.of(type).build();
            }
            return Tensor.from(0.0);
        }
    }

    /**
     * Sorts function input names into query features, normalizer references and match feature
     * inputs.
     */
    static class InputResolver {
        final List<String> usedNormalizers = new ArrayList<>();
        final List<String> fromQuery = new ArrayList<>();
        final List<MatchFeatureInput> fromMF = new ArrayList<>();
        private final Set<String> availableMatchFeatures;
        private final Map<String, String> renamedFeatures;
        private final Set<String> availableNormalizers;

        InputResolver(Set<String> availableMatchFeatures, Map<String, String> renamedFeatures, Set<String> availableNormalizers) {
            this.availableMatchFeatures = availableMatchFeatures;
            this.renamedFeatures = renamedFeatures;
            this.availableNormalizers = availableNormalizers;
        }

        /**
         * Classifies each input and appends it to the matching list.
         *
         * <p>Inputs are checked in this order: {@code query(...)} feature, normalizer name, match
         * feature or {@code relevanceScore}, then the new name of a renamed match feature.
         *
         * @throws IllegalArgumentException if an input matches none of these
         */
        void resolve(Collection<String> allInputs) {
            for (String input : allInputs) {
                String queryFeatureName = asQueryFeature(input);
                if (queryFeatureName != null) {
                    fromQuery.add(queryFeatureName);
                } else if (availableNormalizers.contains(input)) {
                    usedNormalizers.add(input);
                } else if (availableMatchFeatures.contains(input) || input.equals("relevanceScore")) {
                    // Presumably (input name, name the value has among match features); verify in MatchFeatureInput.
                    fromMF.add(new MatchFeatureInput(input, renamedFeatures.getOrDefault(input, input)));
                } else if (renamedFeatures.containsValue(input)) {
                    fromMF.add(new MatchFeatureInput(input, input));
                } else {
                    throw new IllegalArgumentException("Bad config, missing global-phase input: " + input);
                }
            }
        }
    }
}

