package com.yahoo.vespa.model.application.validation.change;

import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.config.model.api.ServiceInfo;
import com.yahoo.vespa.model.application.validation.Validation;
import com.yahoo.vespa.model.container.ApplicationContainer;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.class */
/**
 * Change validator that, for hosted deployments, requires a restart of the services in a container cluster
 * when its ONNX models change and its nodes do not have enough spare memory to avoid a restart.
 *
 * <p>Only clusters present in both the previous and the next model are checked. A restart is required when a
 * model present in both has a different estimated cost, hash or options, or when the set of model ids differs.
 * Every required restart also marks the cluster's ONNX model cost calculator to restart on deploy and stores it.
 */
public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValidator {

    private static final Logger log = Logger.getLogger(RestartOnDeployForOnnxModelChangesValidator.class.getName());

    /** Memory (GiB) deducted from the smallest node's memory before computing what is left for other use. */
    private static final double NODE_MEMORY_OVERHEAD_GB = 0.7;

    /** Lowest accepted value of the computed memory percentage; below this a restart is required. */
    private static final int MIN_MEMORY_PERCENTAGE = 15;

    /** Lowest accepted available memory (GiB) after overhead and models; below this a restart is required. */
    private static final double MIN_AVAILABLE_MEMORY_GB = 0.6;

    private static final double BYTES_PER_GB = 1024.0 * 1024.0 * 1024.0;

    /**
     * For a hosted deployment, adds required restart actions for ONNX model changes in each container cluster,
     * unless the cluster's nodes have enough memory to avoid the restart.
     *
     * @param context the change context holding the previous and the next model
     */
    @Override
    public void validate(Validation.ChangeContext context) {
        if ( ! context.deployState().isHosted()) {
            return;
        }

        for (ApplicationContainerCluster nextCluster : context.model().getContainerClusters().values()) {
            ApplicationContainerCluster currentCluster =
                    context.previousModel().getContainerClusters().get(nextCluster.getName());
            if (currentCluster == null) {
                continue; // cluster not in the previous model: nothing to compare against
            }

            Map<String, OnnxModelCost.ModelInfo> currentModels = currentCluster.onnxModelCostCalculator().models();
            Map<String, OnnxModelCost.ModelInfo> nextModels = nextCluster.onnxModelCostCalculator().models();
            if (enoughMemoryToAvoidRestart(currentCluster, nextCluster, context.deployState().getDeployLogger())) {
                continue;
            }

            log.log(Level.FINE, () -> "Validating %s, current Onnx models:%s, next Onnx models:%s"
                    .formatted(nextCluster, currentModels, nextModels));
            validateModelChanges(nextCluster, currentModels, nextModels).forEach(context::require);
            validateSetOfModels(nextCluster, currentModels, nextModels).forEach(context::require);
        }
    }

    /**
     * Returns one restart action for each model present in both the current and the next models
     * whose estimated cost, hash or options differ.
     */
    private List<ConfigChangeAction> validateModelChanges(ApplicationContainerCluster cluster,
                                                          Map<String, OnnxModelCost.ModelInfo> currentModels,
                                                          Map<String, OnnxModelCost.ModelInfo> nextModels) {
        List<ConfigChangeAction> actions = new ArrayList<>();
        for (OnnxModelCost.ModelInfo nextModel : nextModels.values()) {
            String modelId = nextModel.modelId();
            if ( ! currentModels.containsKey(modelId)) {
                continue; // added models are covered by validateSetOfModels
            }
            // Current first, so the FINE log in modelChanged reads "current -> next"
            modelChanged(currentModels.get(modelId), nextModel).ifPresent(change ->
                    setRestartOnDeployAndAddRestartAction(actions, cluster,
                            "Onnx model '%s' has changed (%s), need to restart services in %s"
                                    .formatted(modelId, change, cluster)));
        }
        return actions;
    }

    /** Returns a single restart action if the set of model ids differs between current and next models. */
    private List<ConfigChangeAction> validateSetOfModels(ApplicationContainerCluster cluster,
                                                         Map<String, OnnxModelCost.ModelInfo> currentModels,
                                                         Map<String, OnnxModelCost.ModelInfo> nextModels) {
        List<ConfigChangeAction> actions = new ArrayList<>();
        Set<String> currentModelIds = currentModels.keySet();
        Set<String> nextModelIds = nextModels.keySet();
        log.log(Level.FINE, () -> "Checking if Onnx model set has changed (%s) -> (%s)"
                .formatted(currentModelIds, nextModelIds));
        if ( ! currentModelIds.equals(nextModelIds)) {
            setRestartOnDeployAndAddRestartAction(actions, cluster,
                    "Onnx model set has changed from %s to %s, need to restart services in %s"
                            .formatted(currentModelIds, nextModelIds, cluster));
        }
        return actions;
    }

    /**
     * Returns a short description of the first difference found between two versions of a model
     * (estimated cost, then hash, then options), or empty if none of these differ.
     */
    private Optional<String> modelChanged(OnnxModelCost.ModelInfo currentModel, OnnxModelCost.ModelInfo nextModel) {
        log.log(Level.FINE, () -> "Checking if model has changed (%s) -> (%s)".formatted(currentModel, nextModel));
        if (currentModel.estimatedCost() != nextModel.estimatedCost()) {
            return Optional.of("estimated cost");
        }
        if (currentModel.hash() != nextModel.hash()) {
            return Optional.of("model hash");
        }
        if ( ! currentModel.onnxModelOptions().equals(nextModel.onnxModelOptions())) {
            return Optional.of("model option(s)");
        }
        return Optional.empty();
    }

    /**
     * Logs the message at INFO, marks the cluster's ONNX model cost calculator to restart on deploy and stores it,
     * then adds a restart action covering every container in the cluster to {@code actions}.
     *
     * @param actions the list the restart action is appended to
     * @param cluster the cluster whose services must be restarted
     * @param message the restart reason, used both in the log and in the action
     */
    public static void setRestartOnDeployAndAddRestartAction(List<ConfigChangeAction> actions,
                                                             ApplicationContainerCluster cluster,
                                                             String message) {
        log.log(Level.INFO, message);
        cluster.onnxModelCostCalculator().setRestartOnDeploy();
        cluster.onnxModelCostCalculator().store();
        List<ServiceInfo> services = cluster.getContainers().stream()
                .<ServiceInfo>map(container -> container.getServiceInfo())
                .toList();
        actions.add(new VespaRestartAction(cluster.id(), message, services));
    }

    /**
     * Returns whether the smallest node of the next cluster has enough memory left, after a fixed overhead and
     * the cost of both the current and the next models, to avoid a restart. Logs the reason when it does not.
     * A cluster without containers always has enough memory.
     */
    private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster currentCluster,
                                                      ApplicationContainerCluster nextCluster,
                                                      DeployLogger deployLogger) {
        List<ApplicationContainer> containers = nextCluster.getContainers();
        if (containers.isEmpty()) {
            return true;
        }

        double currentModelCostGb = onnxModelCostInGb(currentCluster);
        double nextModelCostGb = onnxModelCostInGb(nextCluster);
        // The smallest node is the binding constraint
        double totalMemoryGb = containers.stream()
                .mapToDouble(container -> container.getHostResource().realResources().memoryGiB())
                .min()
                .orElseThrow();
        // Both the current and the next models are counted against the node's memory
        double memoryUsedByModelsGb = currentModelCostGb + nextModelCostGb;
        double availableMemoryGb = Math.max(0.0, totalMemoryGb - NODE_MEMORY_OVERHEAD_GB - memoryUsedByModelsGb);
        int memoryPercentage = (int) (availableMemoryGb / totalMemoryGb * nextCluster.heapSizePercentageOfAvailable());

        String prefix = "Validating Onnx models memory usage for %s".formatted(nextCluster);
        if (memoryPercentage < MIN_MEMORY_PERCENTAGE) {
            deployLogger.log(Level.INFO, "%s, percentage of available memory too low (%d < %d) to avoid restart, consider a flavor with more memory to avoid this"
                    .formatted(prefix, memoryPercentage, MIN_MEMORY_PERCENTAGE));
            return false;
        }
        if (availableMemoryGb < MIN_AVAILABLE_MEMORY_GB) {
            deployLogger.log(Level.INFO, "%s, available memory too low (%.2f Gb < %.2f Gb) to avoid restart, consider a flavor with more memory to avoid this"
                    .formatted(prefix, availableMemoryGb, MIN_AVAILABLE_MEMORY_GB));
            return false;
        }

        log.log(Level.FINE, () -> "%s, enough available memory (%.2f Gb) to avoid restart (models use %.2f Gb)"
                .formatted(prefix, availableMemoryGb, memoryUsedByModelsGb));
        return true;
    }

    /** Returns the aggregated ONNX model cost of the cluster in GiB. */
    private static double onnxModelCostInGb(ApplicationContainerCluster cluster) {
        return cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / BYTES_PER_GB;
    }
}
