package de.kherud.llama;

import de.kherud.llama.args.MiroStat;
import de.kherud.llama.args.Sampler;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

/* loaded from: input_file:de/kherud/llama/InferenceParameters.class */
/**
 * Fluent builder for the parameters of a single inference request.
 *
 * <p>Every setter stores a JSON-encoded value under a fixed key in the inherited
 * {@code parameters} map and returns {@code this}, so calls can be chained. String values are
 * encoded with {@code toJsonString}. Numbers and booleans are stored via {@link String#valueOf}.
 * List- and map-valued options are rendered as JSON arrays. Setting the same option twice
 * overwrites the earlier value.
 */
public final class InferenceParameters extends JsonParameters {
    private static final String PARAM_PROMPT = "prompt";
    private static final String PARAM_INPUT_PREFIX = "input_prefix";
    private static final String PARAM_INPUT_SUFFIX = "input_suffix";
    private static final String PARAM_CACHE_PROMPT = "cache_prompt";
    private static final String PARAM_N_PREDICT = "n_predict";
    private static final String PARAM_TOP_K = "top_k";
    private static final String PARAM_TOP_P = "top_p";
    private static final String PARAM_MIN_P = "min_p";
    private static final String PARAM_TFS_Z = "tfs_z";
    private static final String PARAM_TYPICAL_P = "typical_p";
    private static final String PARAM_TEMPERATURE = "temperature";
    private static final String PARAM_DYNATEMP_RANGE = "dynatemp_range";
    private static final String PARAM_DYNATEMP_EXPONENT = "dynatemp_exponent";
    private static final String PARAM_REPEAT_LAST_N = "repeat_last_n";
    private static final String PARAM_REPEAT_PENALTY = "repeat_penalty";
    private static final String PARAM_FREQUENCY_PENALTY = "frequency_penalty";
    private static final String PARAM_PRESENCE_PENALTY = "presence_penalty";
    private static final String PARAM_MIROSTAT = "mirostat";
    private static final String PARAM_MIROSTAT_TAU = "mirostat_tau";
    private static final String PARAM_MIROSTAT_ETA = "mirostat_eta";
    private static final String PARAM_PENALIZE_NL = "penalize_nl";
    private static final String PARAM_N_KEEP = "n_keep";
    private static final String PARAM_SEED = "seed";
    private static final String PARAM_N_PROBS = "n_probs";
    private static final String PARAM_MIN_KEEP = "min_keep";
    private static final String PARAM_GRAMMAR = "grammar";
    private static final String PARAM_PENALTY_PROMPT = "penalty_prompt";
    private static final String PARAM_IGNORE_EOS = "ignore_eos";
    private static final String PARAM_LOGIT_BIAS = "logit_bias";
    private static final String PARAM_STOP = "stop";
    private static final String PARAM_SAMPLERS = "samplers";
    private static final String PARAM_STREAM = "stream";
    // NOTE(review): currently unused; setUseChatTemplate writes PARAM_USE_JINJA instead.
    private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template";
    private static final String PARAM_USE_JINJA = "use_jinja";
    private static final String PARAM_MESSAGES = "messages";

    /**
     * Creates a parameter set for the given prompt.
     *
     * @param prompt the prompt text, stored under {@code prompt}
     */
    public InferenceParameters(String prompt) {
        setPrompt(prompt);
    }

    /**
     * Creates a joiner that renders its elements as a JSON array, e.g. {@code [a, b, c]}.
     * An empty joiner renders as {@code []}.
     */
    private static StringJoiner newJsonArray() {
        return new StringJoiner(", ", "[", "]");
    }

    /**
     * Sets the prompt.
     *
     * @param prompt the prompt text, stored JSON-encoded under {@code prompt}
     * @return this instance
     */
    public InferenceParameters setPrompt(String prompt) {
        this.parameters.put(PARAM_PROMPT, toJsonString(prompt));
        return this;
    }

    /**
     * Sets the input prefix.
     *
     * @param inputPrefix text stored JSON-encoded under {@code input_prefix}
     * @return this instance
     */
    public InferenceParameters setInputPrefix(String inputPrefix) {
        this.parameters.put(PARAM_INPUT_PREFIX, toJsonString(inputPrefix));
        return this;
    }

    /**
     * Sets the input suffix.
     *
     * @param inputSuffix text stored JSON-encoded under {@code input_suffix}
     * @return this instance
     */
    public InferenceParameters setInputSuffix(String inputSuffix) {
        this.parameters.put(PARAM_INPUT_SUFFIX, toJsonString(inputSuffix));
        return this;
    }

    /**
     * Sets {@code cache_prompt}.
     *
     * @param cachePrompt the flag value
     * @return this instance
     */
    public InferenceParameters setCachePrompt(boolean cachePrompt) {
        this.parameters.put(PARAM_CACHE_PROMPT, String.valueOf(cachePrompt));
        return this;
    }

    /**
     * Sets {@code n_predict}.
     *
     * @param nPredict the value to store
     * @return this instance
     */
    public InferenceParameters setNPredict(int nPredict) {
        this.parameters.put(PARAM_N_PREDICT, String.valueOf(nPredict));
        return this;
    }

    /**
     * Sets {@code top_k}.
     *
     * @param topK the value to store
     * @return this instance
     */
    public InferenceParameters setTopK(int topK) {
        this.parameters.put(PARAM_TOP_K, String.valueOf(topK));
        return this;
    }

    /**
     * Sets {@code top_p}.
     *
     * @param topP the value to store
     * @return this instance
     */
    public InferenceParameters setTopP(float topP) {
        this.parameters.put(PARAM_TOP_P, String.valueOf(topP));
        return this;
    }

    /**
     * Sets {@code min_p}.
     *
     * @param minP the value to store
     * @return this instance
     */
    public InferenceParameters setMinP(float minP) {
        this.parameters.put(PARAM_MIN_P, String.valueOf(minP));
        return this;
    }

    /**
     * Sets {@code tfs_z}.
     *
     * @param tfsZ the value to store
     * @return this instance
     */
    public InferenceParameters setTfsZ(float tfsZ) {
        this.parameters.put(PARAM_TFS_Z, String.valueOf(tfsZ));
        return this;
    }

    /**
     * Sets {@code typical_p}.
     *
     * @param typicalP the value to store
     * @return this instance
     */
    public InferenceParameters setTypicalP(float typicalP) {
        this.parameters.put(PARAM_TYPICAL_P, String.valueOf(typicalP));
        return this;
    }

    /**
     * Sets {@code temperature}.
     *
     * @param temperature the value to store
     * @return this instance
     */
    public InferenceParameters setTemperature(float temperature) {
        this.parameters.put(PARAM_TEMPERATURE, String.valueOf(temperature));
        return this;
    }

    /**
     * Sets {@code dynatemp_range}.
     *
     * @param range the value to store
     * @return this instance
     */
    public InferenceParameters setDynamicTemperatureRange(float range) {
        this.parameters.put(PARAM_DYNATEMP_RANGE, String.valueOf(range));
        return this;
    }

    /**
     * Sets {@code dynatemp_exponent}.
     *
     * @param exponent the value to store
     * @return this instance
     */
    public InferenceParameters setDynamicTemperatureExponent(float exponent) {
        this.parameters.put(PARAM_DYNATEMP_EXPONENT, String.valueOf(exponent));
        return this;
    }

    /**
     * Sets {@code repeat_last_n}.
     *
     * @param repeatLastN the value to store
     * @return this instance
     */
    public InferenceParameters setRepeatLastN(int repeatLastN) {
        this.parameters.put(PARAM_REPEAT_LAST_N, String.valueOf(repeatLastN));
        return this;
    }

    /**
     * Sets {@code repeat_penalty}.
     *
     * @param repeatPenalty the value to store
     * @return this instance
     */
    public InferenceParameters setRepeatPenalty(float repeatPenalty) {
        this.parameters.put(PARAM_REPEAT_PENALTY, String.valueOf(repeatPenalty));
        return this;
    }

    /**
     * Sets {@code frequency_penalty}.
     *
     * @param frequencyPenalty the value to store
     * @return this instance
     */
    public InferenceParameters setFrequencyPenalty(float frequencyPenalty) {
        this.parameters.put(PARAM_FREQUENCY_PENALTY, String.valueOf(frequencyPenalty));
        return this;
    }

    /**
     * Sets {@code presence_penalty}.
     *
     * @param presencePenalty the value to store
     * @return this instance
     */
    public InferenceParameters setPresencePenalty(float presencePenalty) {
        this.parameters.put(PARAM_PRESENCE_PENALTY, String.valueOf(presencePenalty));
        return this;
    }

    /**
     * Sets {@code mirostat} to the ordinal of the given mode.
     *
     * <p>NOTE(review): this relies on the declaration order of {@link MiroStat} matching the
     * numeric mode values the native side expects; confirm before reordering that enum.
     *
     * @param miroStat the mode whose {@code ordinal()} is stored
     * @return this instance
     */
    public InferenceParameters setMiroStat(MiroStat miroStat) {
        this.parameters.put(PARAM_MIROSTAT, String.valueOf(miroStat.ordinal()));
        return this;
    }

    /**
     * Sets {@code mirostat_tau}.
     *
     * @param tau the value to store
     * @return this instance
     */
    public InferenceParameters setMiroStatTau(float tau) {
        this.parameters.put(PARAM_MIROSTAT_TAU, String.valueOf(tau));
        return this;
    }

    /**
     * Sets {@code mirostat_eta}.
     *
     * @param eta the value to store
     * @return this instance
     */
    public InferenceParameters setMiroStatEta(float eta) {
        this.parameters.put(PARAM_MIROSTAT_ETA, String.valueOf(eta));
        return this;
    }

    /**
     * Sets {@code penalize_nl}.
     *
     * @param penalizeNl the flag value
     * @return this instance
     */
    public InferenceParameters setPenalizeNl(boolean penalizeNl) {
        this.parameters.put(PARAM_PENALIZE_NL, String.valueOf(penalizeNl));
        return this;
    }

    /**
     * Sets {@code n_keep}.
     *
     * @param nKeep the value to store
     * @return this instance
     */
    public InferenceParameters setNKeep(int nKeep) {
        this.parameters.put(PARAM_N_KEEP, String.valueOf(nKeep));
        return this;
    }

    /**
     * Sets {@code seed}.
     *
     * @param seed the value to store
     * @return this instance
     */
    public InferenceParameters setSeed(int seed) {
        this.parameters.put(PARAM_SEED, String.valueOf(seed));
        return this;
    }

    /**
     * Sets {@code n_probs}.
     *
     * @param nProbs the value to store
     * @return this instance
     */
    public InferenceParameters setNProbs(int nProbs) {
        this.parameters.put(PARAM_N_PROBS, String.valueOf(nProbs));
        return this;
    }

    /**
     * Sets {@code min_keep}.
     *
     * @param minKeep the value to store
     * @return this instance
     */
    public InferenceParameters setMinKeep(int minKeep) {
        this.parameters.put(PARAM_MIN_KEEP, String.valueOf(minKeep));
        return this;
    }

    /**
     * Sets the grammar.
     *
     * @param grammar grammar text, stored JSON-encoded under {@code grammar}
     * @return this instance
     */
    public InferenceParameters setGrammar(String grammar) {
        this.parameters.put(PARAM_GRAMMAR, toJsonString(grammar));
        return this;
    }

    /**
     * Sets {@code penalty_prompt} from text.
     *
     * @param penaltyPrompt text stored JSON-encoded under {@code penalty_prompt}
     * @return this instance
     */
    public InferenceParameters setPenaltyPrompt(String penaltyPrompt) {
        this.parameters.put(PARAM_PENALTY_PROMPT, toJsonString(penaltyPrompt));
        return this;
    }

    /**
     * Sets {@code penalty_prompt} from token ids, rendered as a JSON array of integers.
     * An empty array leaves any previously set value unchanged.
     *
     * @param tokens the token ids
     * @return this instance
     */
    public InferenceParameters setPenaltyPrompt(int[] tokens) {
        if (tokens.length > 0) {
            StringJoiner array = newJsonArray();
            for (int token : tokens) {
                array.add(String.valueOf(token));
            }
            this.parameters.put(PARAM_PENALTY_PROMPT, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code ignore_eos}.
     *
     * @param ignoreEos the flag value
     * @return this instance
     */
    public InferenceParameters setIgnoreEos(boolean ignoreEos) {
        this.parameters.put(PARAM_IGNORE_EOS, String.valueOf(ignoreEos));
        return this;
    }

    /**
     * Sets {@code logit_bias} from token ids, rendered as {@code [[id, bias], ...]}.
     *
     * <p>Replaces any value stored by the other logit-bias methods of this class. An empty map
     * leaves any previously set value unchanged.
     *
     * @param logitBias bias per token id
     * @return this instance
     */
    public InferenceParameters setTokenIdBias(Map<Integer, Float> logitBias) {
        if (!logitBias.isEmpty()) {
            StringJoiner array = newJsonArray();
            for (Map.Entry<Integer, Float> entry : logitBias.entrySet()) {
                array.add("[" + entry.getKey() + ", " + entry.getValue() + "]");
            }
            this.parameters.put(PARAM_LOGIT_BIAS, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code logit_bias} to {@code [[id, false], ...]} for the given token ids.
     *
     * <p>Replaces any value stored by the other logit-bias methods of this class. An empty
     * collection leaves any previously set value unchanged.
     *
     * @param tokenIds the token ids to mark with {@code false}
     * @return this instance
     */
    public InferenceParameters disableTokenIds(Collection<Integer> tokenIds) {
        if (!tokenIds.isEmpty()) {
            StringJoiner array = newJsonArray();
            // Iterate rather than comparing an index against size(), so the separators always
            // match the elements actually visited.
            for (Integer tokenId : tokenIds) {
                array.add("[" + tokenId + ", false]");
            }
            this.parameters.put(PARAM_LOGIT_BIAS, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code logit_bias} from token strings, rendered as {@code [["text", bias], ...]}.
     *
     * <p>Replaces any value stored by the other logit-bias methods of this class. An empty map
     * leaves any previously set value unchanged.
     *
     * @param logitBias bias per token string; keys are JSON-encoded with {@code toJsonString}
     * @return this instance
     */
    public InferenceParameters setTokenBias(Map<String, Float> logitBias) {
        if (!logitBias.isEmpty()) {
            StringJoiner array = newJsonArray();
            for (Map.Entry<String, Float> entry : logitBias.entrySet()) {
                array.add("[" + toJsonString(entry.getKey()) + ", " + entry.getValue() + "]");
            }
            this.parameters.put(PARAM_LOGIT_BIAS, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code logit_bias} to {@code [["text", false], ...]} for the given token strings.
     *
     * <p>Replaces any value stored by the other logit-bias methods of this class. An empty
     * collection leaves any previously set value unchanged.
     *
     * @param tokens token strings to mark with {@code false}; JSON-encoded with
     *     {@code toJsonString}
     * @return this instance
     */
    public InferenceParameters disableTokens(Collection<String> tokens) {
        if (!tokens.isEmpty()) {
            StringJoiner array = newJsonArray();
            for (String token : tokens) {
                array.add("[" + toJsonString(token) + ", false]");
            }
            this.parameters.put(PARAM_LOGIT_BIAS, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code stop} to a JSON array of the given strings.
     * Passing no strings leaves any previously set value unchanged.
     *
     * @param stopStrings the stop strings, each JSON-encoded with {@code toJsonString}
     * @return this instance
     */
    public InferenceParameters setStopStrings(String... stopStrings) {
        if (stopStrings.length > 0) {
            StringJoiner array = newJsonArray();
            for (String stopString : stopStrings) {
                array.add(toJsonString(stopString));
            }
            this.parameters.put(PARAM_STOP, array.toString());
        }
        return this;
    }

    /**
     * Sets {@code samplers} to a JSON array of sampler names, in the given order.
     * Passing no samplers leaves any previously set value unchanged.
     *
     * @param samplers the samplers to enable
     * @return this instance
     * @throws IllegalArgumentException if a sampler has no known JSON name (previously such a
     *     sampler silently produced a malformed JSON array)
     */
    public InferenceParameters setSamplers(Sampler... samplers) {
        if (samplers.length > 0) {
            StringJoiner array = newJsonArray();
            for (Sampler sampler : samplers) {
                array.add(toSamplerJson(sampler));
            }
            this.parameters.put(PARAM_SAMPLERS, array.toString());
        }
        return this;
    }

    /**
     * Maps a sampler to its quoted JSON name.
     *
     * @throws IllegalArgumentException for samplers without a known name
     */
    private static String toSamplerJson(Sampler sampler) {
        switch (sampler) {
            case TOP_K:
                return "\"top_k\"";
            case TOP_P:
                return "\"top_p\"";
            case MIN_P:
                return "\"min_p\"";
            case TEMPERATURE:
                return "\"temperature\"";
            default:
                throw new IllegalArgumentException("Unsupported sampler: " + sampler);
        }
    }

    /**
     * Stores the flag under {@code use_jinja}.
     *
     * <p>NOTE(review): the method name suggests {@code use_chat_template}, and
     * {@code PARAM_USE_CHAT_TEMPLATE} is declared but unused. Confirm which key the native side
     * reads before changing this.
     *
     * @param useChatTemplate the flag value
     * @return this instance
     */
    public InferenceParameters setUseChatTemplate(boolean useChatTemplate) {
        this.parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate));
        return this;
    }

    /**
     * Sets {@code messages} to a JSON array of {@code {"role", "content"}} objects.
     *
     * <p>A non-empty system message is emitted first with role {@code system}, followed by
     * {@code messages} in order. A {@code null} message list is treated as empty. If validation
     * fails, nothing is stored.
     *
     * @param systemMessage optional system message; {@code null} or empty means none
     * @param messages pairs of (role, content); the role must be {@code user} or
     *     {@code assistant}
     * @return this instance
     * @throws IllegalArgumentException if any role is not {@code user} or {@code assistant},
     *     including a {@code null} role
     */
    public InferenceParameters setMessages(String systemMessage, List<Pair<String, String>> messages) {
        StringJoiner array = newJsonArray();
        if (systemMessage != null && !systemMessage.isEmpty()) {
            array.add("{\"role\": \"system\", \"content\": " + toJsonString(systemMessage) + "}");
        }
        if (messages != null) {
            for (Pair<String, String> message : messages) {
                String role = message.getKey();
                // Constant-first equals so a null role is reported as an invalid role, not an NPE.
                if (!"user".equals(role) && !"assistant".equals(role)) {
                    throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'.");
                }
                array.add("{\"role\":" + toJsonString(role) + ", \"content\": " + toJsonString(message.getValue()) + "}");
            }
        }
        this.parameters.put(PARAM_MESSAGES, array.toString());
        return this;
    }

    /**
     * Sets {@code stream}.
     *
     * <p>The decompiled source notes this method was originally package-private, so it is
     * presumably meant for internal use by the library. Confirm before relying on it externally.
     *
     * @param stream the flag value
     * @return this instance
     */
    public InferenceParameters setStream(boolean stream) {
        this.parameters.put(PARAM_STREAM, String.valueOf(stream));
        return this;
    }

    /** Delegates to {@code JsonParameters#toString()}. */
    @Override
    public String toString() {
        return super.toString();
    }
}
