package ai.vespa.models.handler;

import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;

/* loaded from: input_file:ai/vespa/models/handler/ModelsEvaluationHandler.class */
public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
    private static final String missingValueKey = "missing-value";
    public static final String API_ROOT = "model-evaluation";
    public static final String VERSION_V1 = "v1";
    public static final String EVALUATE = "eval";
    private final ModelsEvaluator modelsEvaluator;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/vespa/models/handler/ModelsEvaluationHandler$ErrorResponse.class */
    public static class ErrorResponse extends Response {
        ErrorResponse(int i, String str) {
            super(i, "{\"error\":\"" + str + "\"}");
        }
    }

    /* loaded from: input_file:ai/vespa/models/handler/ModelsEvaluationHandler$Path.class */
    private static class Path {
        private final String[] segments;

        public Path(HttpRequest httpRequest) {
            this.segments = splitPath(httpRequest);
        }

        Optional<String> segment(int i) {
            return (i < 0 || i >= this.segments.length) ? Optional.empty() : Optional.of(this.segments[i]);
        }

        Optional<Integer> lastIndexOf(String str) {
            for (int length = this.segments.length - 1; length >= 0; length--) {
                if (this.segments[length].equalsIgnoreCase(str)) {
                    return Optional.of(Integer.valueOf(length));
                }
            }
            return Optional.empty();
        }

        public String[] range(int i, Optional<Integer> optional) {
            return (String[]) Arrays.copyOfRange(this.segments, i, optional.isPresent() ? optional.get().intValue() : this.segments.length);
        }

        private static String[] splitPath(HttpRequest httpRequest) {
            String lowerCase = httpRequest.getUri().getPath().toLowerCase();
            if (lowerCase.startsWith("/")) {
                lowerCase = lowerCase.substring("/".length());
            }
            if (lowerCase.endsWith("/")) {
                lowerCase = lowerCase.substring(0, lowerCase.length() - 1);
            }
            return lowerCase.split("/");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/vespa/models/handler/ModelsEvaluationHandler$Response.class */
    public static class Response extends HttpResponse {
        private final byte[] data;

        Response(int i, byte[] bArr) {
            super(i);
            this.data = bArr;
        }

        Response(int i, String str) {
            this(i, str.getBytes(Charset.forName("UTF-8")));
        }

        public String getContentType() {
            return "application/json";
        }

        public void render(OutputStream outputStream) throws IOException {
            outputStream.write(this.data);
        }
    }

    @Inject
    public ModelsEvaluationHandler(ComponentRegistry<ModelsEvaluator> componentRegistry, Executor executor) {
        this((ModelsEvaluator) componentRegistry.getComponent(ModelsEvaluator.class.getName()), executor);
    }

    public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) {
        super(executor);
        this.modelsEvaluator = modelsEvaluator;
        if (modelsEvaluator == null) {
            throw new IllegalArgumentException("missing ModelsEvaluator");
        }
    }

    public HttpResponse handle(HttpRequest httpRequest) {
        Path path = new Path(httpRequest);
        Optional<String> segment = path.segment(0);
        Optional<String> segment2 = path.segment(1);
        Optional<String> segment3 = path.segment(2);
        try {
            if (segment.isEmpty() || !segment.get().equalsIgnoreCase(API_ROOT)) {
                throw new IllegalArgumentException("unknown API");
            }
            if (segment2.isEmpty() || !segment2.get().equalsIgnoreCase(VERSION_V1)) {
                throw new IllegalArgumentException("unknown API version");
            }
            if (segment3.isEmpty()) {
                return listAllModels(httpRequest);
            }
            Model requireModel = this.modelsEvaluator.requireModel(segment3.get());
            Optional<Integer> lastIndexOf = path.lastIndexOf(EVALUATE);
            String[] range = path.range(3, lastIndexOf);
            return lastIndexOf.isPresent() ? evaluateModel(httpRequest, requireModel, range) : listModelInformation(httpRequest, requireModel, range);
        } catch (IllegalArgumentException e) {
            return new ErrorResponse(404, Exceptions.toMessageString(e));
        } catch (IllegalStateException e2) {
            return new ErrorResponse(400, Exceptions.toMessageString(e2));
        }
    }

    private HttpResponse evaluateModel(HttpRequest httpRequest, Model model, String[] strArr) {
        FunctionEvaluator evaluatorOf = model.evaluatorOf(strArr);
        property(httpRequest, missingValueKey).ifPresent(str -> {
            evaluatorOf.setMissingValue(Tensor.from(str));
        });
        for (Map.Entry entry : evaluatorOf.function().argumentTypes().entrySet()) {
            Optional<String> property = property(httpRequest, (String) entry.getKey());
            if (property.isPresent()) {
                try {
                    evaluatorOf.bind((String) entry.getKey(), Tensor.from((TensorType) entry.getValue(), property.get()));
                } catch (IllegalArgumentException e) {
                    evaluatorOf.bind((String) entry.getKey(), property.get());
                }
            }
        }
        Tensor evaluate = evaluatorOf.evaluate();
        String lowerCase = property(httpRequest, "format.tensors").orElse("short").toLowerCase();
        boolean z = -1;
        switch (lowerCase.hashCode()) {
            case -1753112864:
                if (lowerCase.equals("long-value")) {
                    z = 3;
                    break;
                }
                break;
            case -1014535064:
                if (lowerCase.equals("string-long ")) {
                    z = 5;
                    break;
                }
                break;
            case -891985903:
                if (lowerCase.equals("string")) {
                    z = 4;
                    break;
                }
                break;
            case -129114912:
                if (lowerCase.equals("short-value")) {
                    z = 2;
                    break;
                }
                break;
            case 3327612:
                if (lowerCase.equals("long")) {
                    z = true;
                    break;
                }
                break;
            case 109413500:
                if (lowerCase.equals("short")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return new Response(200, JsonFormat.encode(evaluate, true, false));
            case true:
                return new Response(200, JsonFormat.encode(evaluate, false, false));
            case true:
                return new Response(200, JsonFormat.encode(evaluate, true, true));
            case true:
                return new Response(200, JsonFormat.encode(evaluate, false, true));
            case true:
                return new Response(200, evaluate.toString(true, true).getBytes(StandardCharsets.UTF_8));
            case true:
                return new Response(200, evaluate.toString(true, false).getBytes(StandardCharsets.UTF_8));
            default:
                return new ErrorResponse(400, "Unknown tensor format '" + property(httpRequest, "format.tensors") + "'");
        }
    }

    private HttpResponse listAllModels(HttpRequest httpRequest) {
        Slime slime = new Slime();
        Cursor object = slime.setObject();
        this.modelsEvaluator.models().keySet().stream().sorted().forEach(str -> {
            object.setString(str, baseUrl(httpRequest) + str);
        });
        return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
    }

    private HttpResponse listModelInformation(HttpRequest httpRequest, Model model, String[] strArr) {
        Slime slime = new Slime();
        Cursor object = slime.setObject();
        object.setString("model", model.name());
        if (strArr.length == 0) {
            listFunctions(httpRequest, model, object);
        } else {
            listFunctionDetails(httpRequest, model, strArr, object);
        }
        return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
    }

    private void listFunctions(HttpRequest httpRequest, Model model, Cursor cursor) {
        Cursor array = cursor.setArray("functions");
        for (ExpressionFunction expressionFunction : model.functions()) {
            listFunctionDetails(httpRequest, model, new String[]{expressionFunction.getName()}, array.addObject());
        }
    }

    private void listFunctionDetails(HttpRequest httpRequest, Model model, String[] strArr, Cursor cursor) {
        String join = String.join(".", strArr);
        FunctionEvaluator evaluatorOf = model.evaluatorOf(strArr);
        cursor.setString("function", join);
        cursor.setString("info", baseUrl(httpRequest) + model.name() + "/" + join);
        cursor.setString(EVALUATE, baseUrl(httpRequest) + model.name() + "/" + join + "/eval");
        Cursor array = cursor.setArray("arguments");
        Map argumentTypes = evaluatorOf.function().argumentTypes();
        ArrayList<String> arrayList = new ArrayList(argumentTypes.keySet());
        Collections.sort(arrayList);
        for (String str : arrayList) {
            Cursor addObject = array.addObject();
            addObject.setString("name", str);
            addObject.setString("type", ((TensorType) argumentTypes.get(str)).toString());
        }
    }

    private Optional<String> property(HttpRequest httpRequest, String str) {
        return Optional.ofNullable(httpRequest.getProperty(str));
    }

    private String baseUrl(HttpRequest httpRequest) {
        URI uri = httpRequest.getUri();
        StringBuilder sb = new StringBuilder();
        sb.append(uri.getScheme()).append("://");
        if (httpRequest.getHeader("Host") != null) {
            sb.append(httpRequest.getHeader("Host"));
        } else {
            sb.append(uri.getHost());
            if (uri.getPort() >= 0) {
                sb.append(":").append(uri.getPort());
            }
        }
        sb.append("/").append(API_ROOT).append("/").append(VERSION_V1).append("/");
        return sb.toString();
    }
}
