package com.unfbx.chatgpt.utils;

import cn.hutool.core.util.StrUtil;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import com.unfbx.chatgpt.entity.chat.Message;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/unfbx/chatgpt/utils/TikTokensUtil.class */
/**
 * Static helpers for encoding text to tokens, decoding tokens back to text and counting
 * tokens, built on top of the jtokkit library.
 *
 * <p>Each operation is offered in three flavours, selected by the type of the first argument:
 * an already resolved {@link Encoding}, an {@link EncodingType}, or a model name that is
 * matched against {@link ModelType#getName()}.
 */
public class TikTokensUtil {
    private static final Logger log = LoggerFactory.getLogger(TikTokensUtil.class);

    /**
     * Registry shared by every lookup in this class. It is created once instead of on every
     * call because the jtokkit documentation describes registry creation as expensive and the
     * registry itself as safe to share.
     */
    private static final EncodingRegistry REGISTRY = Encodings.newDefaultEncodingRegistry();

    /**
     * Encodes text with the given encoding.
     *
     * @param encoding encoding used to tokenize the text
     * @param str text to tokenize; may be {@code null} or blank
     * @return the token ids, or a new empty list when {@code str} is blank
     */
    public static List<Integer> encode(@NotNull Encoding encoding, String str) {
        if (StrUtil.isBlank(str)) {
            return new ArrayList<>();
        }
        return encoding.encode(str);
    }

    /**
     * Counts the tokens of a text under the given encoding.
     *
     * @param encoding encoding used to tokenize the text
     * @param str text to tokenize; may be {@code null} or blank
     * @return the number of tokens, {@code 0} when {@code str} is blank
     */
    public static int tokens(@NotNull Encoding encoding, String str) {
        return encode(encoding, str).size();
    }

    /**
     * Decodes token ids back into text with the given encoding.
     *
     * @param encoding encoding used to decode
     * @param list token ids to decode
     * @return the decoded text
     */
    public static String decode(@NotNull Encoding encoding, @NotNull List<Integer> list) {
        return encoding.decode(list);
    }

    /**
     * Looks up the encoding for an encoding type in the shared registry.
     *
     * @param encodingType the encoding type to resolve
     * @return the encoding the registry holds for {@code encodingType}
     */
    public static Encoding getEncoding(@NotNull EncodingType encodingType) {
        return REGISTRY.getEncoding(encodingType);
    }

    /**
     * Encodes text with the encoding registered for the given encoding type.
     *
     * @param encodingType the encoding type to use
     * @param str text to tokenize; may be {@code null} or blank
     * @return the token ids, or a new empty list when {@code str} is blank
     */
    public static List<Integer> encode(@NotNull EncodingType encodingType, String str) {
        if (StrUtil.isBlank(str)) {
            return new ArrayList<>();
        }
        return getEncoding(encodingType).encode(str);
    }

    /**
     * Counts the tokens of a text under the given encoding type.
     *
     * @param encodingType the encoding type to use
     * @param str text to tokenize; may be {@code null} or blank
     * @return the number of tokens, {@code 0} when {@code str} is blank
     */
    public static int tokens(@NotNull EncodingType encodingType, String str) {
        return encode(encodingType, str).size();
    }

    /**
     * Decodes token ids back into text with the encoding registered for the given type.
     *
     * @param encodingType the encoding type to use
     * @param list token ids to decode
     * @return the decoded text
     */
    public static String decode(@NotNull EncodingType encodingType, @NotNull List<Integer> list) {
        return getEncoding(encodingType).decode(list);
    }

    /**
     * Looks up the encoding for a model name.
     *
     * @param str model name, compared against {@link ModelType#getName()}
     * @return the encoding for the model, or {@code null} when no {@link ModelType} has that
     *     name (a warning is logged in that case)
     */
    public static Encoding getEncoding(@NotNull String str) {
        ModelType modelType = getModelTypeByName(str);
        if (modelType == null) {
            return null;
        }
        return REGISTRY.getEncodingForModel(modelType);
    }

    /**
     * Encodes text with the encoding of the named model.
     *
     * @param str model name
     * @param str2 text to tokenize; may be {@code null} or blank
     * @return the token ids; a new empty list when {@code str2} is blank or when the model
     *     name cannot be resolved (a warning is logged in the latter case)
     */
    public static List<Integer> encode(@NotNull String str, String str2) {
        if (StrUtil.isBlank(str2)) {
            return new ArrayList<>();
        }
        Encoding encoding = getEncoding(str);
        if (encoding == null) {
            // The model name must be passed as an argument, otherwise "{}" is logged literally.
            log.warn("[{}]模型不存在或者暂不支持计算tokens，直接返回tokens==0", str);
            return new ArrayList<>();
        }
        return encoding.encode(str2);
    }

    /**
     * Counts the tokens of a text under the encoding of the named model.
     *
     * @param str model name
     * @param str2 text to tokenize; may be {@code null} or blank
     * @return the number of tokens; {@code 0} when {@code str2} is blank or the model name
     *     cannot be resolved
     */
    public static int tokens(@NotNull String str, String str2) {
        return encode(str, str2).size();
    }

    /**
     * Estimates the number of prompt tokens a list of chat messages occupies for a model.
     *
     * <p>For every message the token counts of its content, role and name are summed together
     * with a fixed per-message overhead; a per-name adjustment is added when the message has a
     * non-blank name. The overheads are 4 / -1 for {@code gpt-3.5-turbo} and
     * {@code gpt-3.5-turbo-0301}, 3 / 1 for {@code gpt-4} and {@code gpt-4-0314}, and 0 / 0
     * for any other model name. A constant 3 is added to the final sum.
     * NOTE(review): these constants appear to follow the chat token-counting recipe published
     * in the OpenAI cookbook; confirm against it before adding further models.
     *
     * <p>When the model name cannot be resolved, every text contributes 0 tokens and only the
     * fixed overheads are counted.
     *
     * @param str model name
     * @param list chat messages to measure
     * @return the estimated token count
     */
    public static int tokens(@NotNull String str, @NotNull List<Message> list) {
        int tokensPerMessage = 0;
        int tokensPerName = 0;
        if (str.equals("gpt-3.5-turbo-0301") || str.equals("gpt-3.5-turbo")) {
            tokensPerMessage = 4;
            tokensPerName = -1;
        }
        if (str.equals("gpt-4") || str.equals("gpt-4-0314")) {
            tokensPerMessage = 3;
            tokensPerName = 1;
        }

        // Resolve the encoding once per call rather than once per message field.
        Encoding encoding = getEncoding(str);
        if (encoding == null) {
            log.warn("[{}]模型不存在或者暂不支持计算tokens，直接返回tokens==0", str);
        }

        int total = 0;
        for (Message message : list) {
            total += tokensPerMessage;
            total += countTokens(encoding, message.getContent());
            total += countTokens(encoding, message.getRole());
            total += countTokens(encoding, message.getName());
            if (StrUtil.isNotBlank(message.getName())) {
                total += tokensPerName;
            }
        }
        return total + 3;
    }

    /**
     * Decodes token ids back into text with the encoding of the named model.
     *
     * @param str model name
     * @param list token ids to decode
     * @return the decoded text
     * @throws NullPointerException if the model name cannot be resolved to an encoding
     */
    public static String decode(@NotNull String str, @NotNull List<Integer> list) {
        Encoding encoding = Objects.requireNonNull(
                getEncoding(str), "No encoding available for model: " + str);
        return encoding.decode(list);
    }

    /** Counts tokens of {@code text}, treating a missing encoding as zero tokens. */
    private static int countTokens(Encoding encoding, String text) {
        if (encoding == null) {
            return 0;
        }
        return tokens(encoding, text);
    }

    /**
     * Finds the {@link ModelType} whose name equals the given string.
     *
     * @param str model name to look for
     * @return the matching model type, or {@code null} (after logging a warning) when none
     *     matches
     */
    private static ModelType getModelTypeByName(String str) {
        for (ModelType modelType : ModelType.values()) {
            if (modelType.getName().equals(str)) {
                return modelType;
            }
        }
        log.warn("[{}]模型不存在或者暂不支持计算tokens", str);
        return null;
    }
}
