package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Objects;

/**
 * Settings used to build ONNX Runtime {@link OrtSession.SessionOptions} for model evaluation.
 *
 * <p>Defaults: full graph optimization ({@code ALL_OPT}), sequential execution, inter- and intra-op
 * thread counts of a quarter of the available processors (at least one), and no GPU.
 *
 * <p>Instances are mutable through their setters and contain no synchronization.
 */
public class OnnxEvaluatorOptions {

    /** Graph optimization level; always {@code ALL_OPT} since no setter exists. */
    private final OrtSession.SessionOptions.OptLevel optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
    private OrtSession.SessionOptions.ExecutionMode executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
    private int interOpThreads;
    private int intraOpThreads;
    /** CUDA device index; a negative value means no GPU is requested. */
    private int gpuDeviceNumber;
    private boolean gpuDeviceRequired;

    /** Creates options with the defaults described in the class documentation. */
    public OnnxEvaluatorOptions() {
        // -4 means "a quarter of the available processors, at least one" (see calculateThreads).
        int quarterOfProcessors = calculateThreads(-4);
        this.interOpThreads = quarterOfProcessors;
        this.intraOpThreads = quarterOfProcessors;
        this.gpuDeviceNumber = -1;
        this.gpuDeviceRequired = false;
    }

    /**
     * Creates a new ONNX Runtime session options object configured from these settings.
     *
     * <p>The inter-op thread count is only applied in {@code PARALLEL} execution mode; otherwise the
     * inter-op pool is set to one thread. The CPU arena allocator is always disabled.
     *
     * @param loadCuda whether to register the CUDA execution provider on {@link #gpuDeviceNumber()}
     * @return a new, caller-owned session options object (it is {@link AutoCloseable})
     * @throws OrtException if an option cannot be applied, e.g. when the CUDA provider cannot be
     *         added; the partially configured options object is closed before the exception propagates
     */
    public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            options.setOptimizationLevel(optimizationLevel);
            options.setExecutionMode(executionMode);
            options.setInterOpNumThreads(executionMode == OrtSession.SessionOptions.ExecutionMode.PARALLEL ? interOpThreads : 1);
            options.setIntraOpNumThreads(intraOpThreads);
            options.setCPUArenaAllocator(false);
            if (loadCuda) {
                options.addCUDA(gpuDeviceNumber);
            }
            return options;
        } catch (OrtException | RuntimeException e) {
            // The caller never receives this object on failure, so release its native resources here
            // instead of leaking them (callers may retry, e.g. without CUDA).
            try {
                options.close();
            } catch (Exception closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Sets the execution mode from its name, case-insensitively: {@code "parallel"} or
     * {@code "sequential"}. Any other value, including {@code null}, is ignored.
     */
    public void setExecutionMode(String mode) {
        if ("parallel".equalsIgnoreCase(mode)) {
            this.executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
        } else if ("sequential".equalsIgnoreCase(mode)) {
            this.executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
        }
    }

    /**
     * Sets the inter-op thread count. Negative values are ignored.
     *
     * <p>This count is only used in parallel execution mode (see {@link #getOptions(boolean)}).
     */
    public void setInterOpThreads(int threads) {
        if (threads >= 0) {
            this.interOpThreads = threads;
        }
    }

    /** Sets the intra-op thread count. Negative values are ignored. */
    public void setIntraOpThreads(int threads) {
        if (threads >= 0) {
            this.intraOpThreads = threads;
        }
    }

    /**
     * Sets both thread counts. A non-negative value is used as-is. A negative value {@code -n} means
     * {@code ceil(availableProcessors / n)}, but at least one.
     */
    public void setThreads(int interOpThreads, int intraOpThreads) {
        this.interOpThreads = calculateThreads(interOpThreads);
        this.intraOpThreads = calculateThreads(intraOpThreads);
    }

    /**
     * Resolves a thread specification: non-negative values are returned unchanged; a negative value
     * {@code -n} becomes {@code max(1, ceil(availableProcessors / n))}.
     */
    private static int calculateThreads(int threads) {
        if (threads >= 0) {
            return threads;
        }
        int divisor = -threads;
        return Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / (double) divisor));
    }

    /**
     * Sets the GPU device to use and whether it is required.
     *
     * @param deviceNumber CUDA device index; a negative value means no GPU is requested
     * @param required the value subsequently reported by {@link #gpuDeviceRequired()}
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        this.gpuDeviceNumber = deviceNumber;
        this.gpuDeviceRequired = required;
    }

    /** Sets the GPU device to use, leaving the "required" flag unchanged. */
    public void setGpuDevice(int deviceNumber) {
        this.gpuDeviceNumber = deviceNumber;
    }

    /** Returns whether a GPU device is requested, i.e. the device number is non-negative. */
    public boolean requestingGpu() {
        return gpuDeviceNumber > -1;
    }

    /** Returns the "GPU required" flag set through {@link #setGpuDevice(int, boolean)}. */
    public boolean gpuDeviceRequired() {
        return gpuDeviceRequired;
    }

    /** Returns the configured CUDA device index, or a negative value when no GPU is requested. */
    public int gpuDeviceNumber() {
        return gpuDeviceNumber;
    }

    /** Returns an independent copy of these options. */
    public OnnxEvaluatorOptions copy() {
        OnnxEvaluatorOptions copy = new OnnxEvaluatorOptions();
        copy.gpuDeviceNumber = gpuDeviceNumber;
        copy.gpuDeviceRequired = gpuDeviceRequired;
        copy.executionMode = executionMode;
        copy.interOpThreads = interOpThreads;
        copy.intraOpThreads = intraOpThreads;
        // optimizationLevel is final and identical for every instance, so there is nothing to copy.
        return copy;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        OnnxEvaluatorOptions other = (OnnxEvaluatorOptions) obj;
        return interOpThreads == other.interOpThreads
                && intraOpThreads == other.intraOpThreads
                && gpuDeviceNumber == other.gpuDeviceNumber
                && gpuDeviceRequired == other.gpuDeviceRequired
                && optimizationLevel == other.optimizationLevel
                && executionMode == other.executionMode;
    }

    @Override
    public int hashCode() {
        return Objects.hash(optimizationLevel, executionMode, interOpThreads, intraOpThreads, gpuDeviceNumber, gpuDeviceRequired);
    }

    @Override
    public String toString() {
        return "OnnxEvaluatorOptions{optimizationLevel=" + optimizationLevel
                + ", executionMode=" + executionMode
                + ", interOpThreads=" + interOpThreads
                + ", intraOpThreads=" + intraOpThreads
                + ", gpuDeviceNumber=" + gpuDeviceNumber
                + ", gpuDeviceRequired=" + gpuDeviceRequired
                + "}";
    }
}
