/*
 * Decompiled with CFR 0.152.
 */
package com.google.genai;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.genai.LocalTokenizerLoader;
import com.google.genai.LocalTokenizerProcessor;
import com.google.genai.Token;
import com.google.genai.Transformers;
import com.google.genai.proto.SentencepieceModel;
import com.google.genai.types.ComputeTokensResult;
import com.google.genai.types.Content;
import com.google.genai.types.CountTokensConfig;
import com.google.genai.types.CountTokensResult;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.TokensInfo;
import com.google.genai.types.Tool;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Counts and computes tokens on the client side with a SentencePiece model.
 *
 * <p>Token counts come from {@link LocalTokenizerProcessor#encode} rather than from a
 * countTokens/computeTokens request. Only text is tokenized:
 *
 * <ul>
 *   <li>part text;
 *   <li>function call/response names and map contents;
 *   <li>tool function declarations;
 *   <li>response schemas;
 *   <li>system instructions.
 * </ul>
 *
 * <p>Parts carrying {@code fileData} or {@code inlineData} are rejected with an
 * {@link IllegalArgumentException}. Other unsupported content fields are ignored for counting and
 * reported in a logged warning.
 */
public final class LocalTokenizer {
    /** SentencePiece model; consulted for a piece's type when converting tokens to bytes. */
    private final SentencepieceModel.ModelProto modelProto;
    /** Text-to-token encoder (built from {@link #modelProto} by the public constructor). */
    private final LocalTokenizerProcessor tokenizer;

    /**
     * Creates a tokenizer for {@code modelName}.
     *
     * <p>The tokenizer name is resolved with {@link LocalTokenizerLoader#getTokenizerName}. The
     * SentencePiece model is obtained from {@link LocalTokenizerLoader#loadModelProto}. How the
     * model is located, and what happens for unsupported model names, is defined by that loader.
     *
     * @param modelName name of the model whose tokenizer should be used
     */
    public LocalTokenizer(String modelName) {
        String tokenizerName = LocalTokenizerLoader.getTokenizerName(modelName);
        this.modelProto = LocalTokenizerLoader.loadModelProto(tokenizerName);
        this.tokenizer = new LocalTokenizerProcessor(this.modelProto);
    }

    /** Test-only constructor that injects the model proto and the encoder directly. */
    @VisibleForTesting
    LocalTokenizer(SentencepieceModel.ModelProto modelProto, LocalTokenizerProcessor tokenizer) {
        this.modelProto = modelProto;
        this.tokenizer = tokenizer;
    }

    /**
     * Counts the tokens of {@code contents} plus the text-bearing parts of {@code config}.
     *
     * <p>The total is the sum of the token counts of every extracted text. The texts come from:
     *
     * <ul>
     *   <li>the contents;
     *   <li>the config's tools (function declarations);
     *   <li>the response schema of its generation config;
     *   <li>its system instruction.
     * </ul>
     *
     * <p>No other config field is counted.
     *
     * @param contents contents to count; must not be null
     * @param config optional config; {@code null} counts {@code contents} only
     * @return a result whose {@code totalTokens} is the summed count
     * @throws NullPointerException if {@code contents} is null
     * @throws IllegalArgumentException if any part carries {@code fileData} or {@code inlineData}
     */
    public CountTokensResult countTokens(List<Content> contents, CountTokensConfig config) {
        Objects.requireNonNull(contents, "contents");
        TextsAccumulator accumulator = new TextsAccumulator();
        accumulator.addContents(contents);
        if (config != null) {
            config.tools().ifPresent(accumulator::addTools);
            config.generationConfig()
                    .flatMap(generationConfig -> generationConfig.responseSchema())
                    .ifPresent(accumulator::addSchema);
            config.systemInstruction()
                    .ifPresent(instruction -> accumulator.addContents(
                            (List<Content>) Transformers.tContents(instruction)));
        }
        int totalTokens = 0;
        for (String text : accumulator.getTexts()) {
            totalTokens += this.tokenizer.encode(text).size();
        }
        return CountTokensResult.builder().totalTokens(totalTokens).build();
    }

    /**
     * Counts the tokens of {@code contents} without any config.
     *
     * @see #countTokens(List, CountTokensConfig)
     */
    public CountTokensResult countTokens(List<Content> contents) {
        return this.countTokens(contents, null);
    }

    /**
     * Counts the tokens of a single content plus the text-bearing parts of {@code config}.
     *
     * @see #countTokens(List, CountTokensConfig)
     */
    public CountTokensResult countTokens(Content content, CountTokensConfig config) {
        return this.countTokens(ImmutableList.of(content), config);
    }

    /**
     * Counts the tokens of a single content without any config.
     *
     * @see #countTokens(List, CountTokensConfig)
     */
    public CountTokensResult countTokens(Content content) {
        return this.countTokens(content, null);
    }

    /**
     * Counts the tokens of a string plus the text-bearing parts of {@code config}.
     *
     * <p>The string is converted with {@code Transformers.tContents}.
     *
     * @see #countTokens(List, CountTokensConfig)
     */
    public CountTokensResult countTokens(String content, CountTokensConfig config) {
        return this.countTokens((List<Content>) Transformers.tContents(content), config);
    }

    /**
     * Counts the tokens of a string without any config.
     *
     * @see #countTokens(String, CountTokensConfig)
     */
    public CountTokensResult countTokens(String content) {
        return this.countTokens(content, null);
    }

    /**
     * Computes token ids and token bytes for every part of {@code contents}.
     *
     * <p>One {@link TokensInfo} is produced per part, in order, for each content that has parts.
     * Each one carries the content's role, or {@code null} when it has none. Contents without parts
     * contribute nothing.
     *
     * @param contents contents to tokenize; must not be null
     * @return the per-part token information
     * @throws NullPointerException if {@code contents} is null
     * @throws IllegalArgumentException if any part carries {@code fileData} or {@code inlineData},
     *     or a BYTE piece is not of the form {@code <0xHH>}
     */
    public ComputeTokensResult computeTokens(List<Content> contents) {
        Objects.requireNonNull(contents, "contents");
        List<TokensInfo> tokensInfos = new ArrayList<>();
        for (Content content : contents) {
            if (!content.parts().isPresent()) {
                continue;
            }
            for (Part part : content.parts().get()) {
                // A fresh accumulator per part, so each TokensInfo covers only that part's texts.
                TextsAccumulator partAccumulator = new TextsAccumulator();
                partAccumulator.addPart(part);
                List<Long> tokenIds = new ArrayList<>();
                List<byte[]> tokenBytes = new ArrayList<>();
                for (String text : partAccumulator.getTexts()) {
                    for (Token token : this.tokenizer.encode(text)) {
                        tokenIds.add(Long.valueOf(token.id()));
                        tokenBytes.add(tokenStrToBytes(
                                token.text(), this.modelProto.getPieces(token.id()).getType()));
                    }
                }
                tokensInfos.add(TokensInfo.builder()
                        .tokenIds(tokenIds)
                        .tokens(tokenBytes)
                        .role(content.role().orElse(null))
                        .build());
            }
        }
        return ComputeTokensResult.builder().tokensInfo(tokensInfos).build();
    }

    /**
     * Computes token information for a single content.
     *
     * @see #computeTokens(List)
     */
    public ComputeTokensResult computeTokens(Content content) {
        return this.computeTokens(ImmutableList.of(content));
    }

    /**
     * Computes token information for a string.
     *
     * <p>The string is converted with {@code Transformers.tContents}.
     *
     * @see #computeTokens(List)
     */
    public ComputeTokensResult computeTokens(String content) {
        return this.computeTokens((List<Content>) Transformers.tContents(content));
    }

    /**
     * Converts a token's piece text to the bytes it stands for.
     *
     * <p>{@code BYTE} pieces are written as {@code <0xHH>} and map to that single byte. Every other
     * piece is UTF-8 encoded after replacing U+2581 with an ASCII space.
     */
    private static byte[] tokenStrToBytes(
            String token, SentencepieceModel.ModelProto.SentencePiece.Type type) {
        if (type == SentencepieceModel.ModelProto.SentencePiece.Type.BYTE) {
            return new byte[] {(byte) parseHexByte(token)};
        }
        return token.replace('\u2581', ' ').getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Parses a byte piece of the exact form {@code <0xHH>} (two hex digits) into 0..255.
     *
     * @throws IllegalArgumentException if {@code token} is not of that form
     */
    private static int parseHexByte(String token) {
        if (token.length() != 6 || !token.startsWith("<0x") || !token.endsWith(">")) {
            throw new IllegalArgumentException("Invalid byte format: " + token);
        }
        // Decode the two digits explicitly. Integer.parseInt would accept a sign ("<0x-1>" -> -1,
        // silently stored as 0xFF). Two hex digits can never exceed 0xFF, so no range check is
        // needed.
        int high = hexDigitValue(token.charAt(3));
        int low = hexDigitValue(token.charAt(4));
        if (high < 0 || low < 0) {
            throw new IllegalArgumentException("Invalid hex value: " + token);
        }
        return (high << 4) | low;
    }

    /** Returns the value of an ASCII hex digit, or -1 if {@code c} is not one. */
    private static int hexDigitValue(char c) {
        if (c >= '0' && c <= '9') {
            return c - '0';
        }
        if (c >= 'a' && c <= 'f') {
            return c - 'a' + 10;
        }
        if (c >= 'A' && c <= 'F') {
            return c - 'A' + 10;
        }
        return -1;
    }

    /**
     * Collects, in order, every string that contributes to a token count.
     *
     * <p>The strings come from contents, tools and schemas. A fresh instance is created for each
     * count/compute call.
     */
    private static class TextsAccumulator {
        private static final Logger logger = Logger.getLogger(TextsAccumulator.class.getName());
        private final List<String> texts = new ArrayList<String>();

        private TextsAccumulator() {
        }

        /** Returns the live list of collected texts (not a copy). */
        public List<String> getTexts() {
            return this.texts;
        }

        /** Adds the texts of every content in {@code contents}, in order. */
        public void addContents(List<Content> contents) {
            for (Content content : contents) {
                this.addContent(content);
            }
        }

        /**
         * Adds the texts of {@code content}.
         *
         * <p>Logs a warning when the content holds fields that the counted copy does not. The
         * comparison uses {@code Objects.equals}, which assumes {@code Content} has value
         * equality.
         */
        public void addContent(Content content) {
            Content countedContent = this.addContentAndBuildCounted(content);
            if (!Objects.equals(content, countedContent)) {
                logger.warning("Content contains unsupported types for token counting. Supported fields " + countedContent + ". Got " + content + ".");
            }
        }

        /**
         * Adds the texts of {@code content}.
         *
         * <p>Returns a copy that keeps only its role and the supported fields of each part.
         */
        private Content addContentAndBuildCounted(Content content) {
            Content.Builder countedContentBuilder = Content.builder();
            content.role().ifPresent(countedContentBuilder::role);
            if (content.parts().isPresent()) {
                List<Part> countedParts = content.parts().get().stream()
                        .map(this::addPartAndBuildCounted)
                        .collect(Collectors.toList());
                countedContentBuilder.parts(countedParts);
            }
            return countedContentBuilder.build();
        }

        /** Adds the texts of a single part, discarding the counted copy. */
        private void addPart(Part part) {
            this.addPartAndBuildCounted(part);
        }

        /**
         * Adds the text-bearing fields of {@code part} and returns a copy holding only the
         * supported fields.
         *
         * <p>The supported fields are {@code videoMetadata} (copied, not tokenized),
         * {@code functionCall}, {@code functionResponse} and {@code text}.
         *
         * @throws IllegalArgumentException if the part carries {@code fileData} or
         *     {@code inlineData}
         */
        private Part addPartAndBuildCounted(Part part) {
            Part.Builder countedPartBuilder = Part.builder();
            if (part.fileData().isPresent() || part.inlineData().isPresent()) {
                throw new IllegalArgumentException("LocalTokenizers do not support non-text content types.");
            }
            part.videoMetadata().ifPresent(countedPartBuilder::videoMetadata);
            part.functionCall().ifPresent(functionCall -> {
                this.addFunctionCall(functionCall);
                countedPartBuilder.functionCall(functionCall);
            });
            part.functionResponse().ifPresent(functionResponse -> {
                this.addFunctionResponse(functionResponse);
                countedPartBuilder.functionResponse(functionResponse);
            });
            part.text().ifPresent(text -> {
                this.texts.add(text);
                countedPartBuilder.text(text);
            });
            return countedPartBuilder.build();
        }

        /** Adds a function call's name and the keys/string leaves of its arguments. */
        public void addFunctionCall(FunctionCall functionCall) {
            functionCall.name().ifPresent(this.texts::add);
            functionCall.args().ifPresent(this::traverseMap);
        }

        /** Adds a function response's name and the keys/string leaves of its response map. */
        public void addFunctionResponse(FunctionResponse functionResponse) {
            functionResponse.name().ifPresent(this.texts::add);
            functionResponse.response().ifPresent(this::traverseMap);
        }

        /** Adds the function declarations of every tool. */
        public void addTools(List<Tool> tools) {
            for (Tool tool : tools) {
                this.addTool(tool);
            }
        }

        /** Adds the function declarations of {@code tool}; other tool kinds are ignored. */
        public void addTool(Tool tool) {
            if (tool.functionDeclarations().isPresent()) {
                for (FunctionDeclaration functionDeclaration : tool.functionDeclarations().get()) {
                    this.addFunctionDeclaration(functionDeclaration);
                }
            }
        }

        /** Adds a declaration's name, description and parameter schema. */
        private void addFunctionDeclaration(FunctionDeclaration functionDeclaration) {
            functionDeclaration.name().ifPresent(this.texts::add);
            functionDeclaration.description().ifPresent(this.texts::add);
            functionDeclaration.parameters().ifPresent(this::addSchema);
        }

        /**
         * Recursively adds the text-bearing fields of {@code schema}.
         *
         * <p>These are: format, description, enum values, required names, item schema, property
         * names with their schemas, and the string content of the example.
         */
        public void addSchema(Schema schema) {
            schema.format().ifPresent(this.texts::add);
            schema.description().ifPresent(this.texts::add);
            schema.enum_().ifPresent(this.texts::addAll);
            schema.required().ifPresent(this.texts::addAll);
            schema.items().ifPresent(this::addSchema);
            if (schema.properties().isPresent()) {
                for (Map.Entry<String, Schema> entry : schema.properties().get().entrySet()) {
                    this.texts.add(entry.getKey());
                    this.addSchema(entry.getValue());
                }
            }
            schema.example().ifPresent(this::traverseObject);
        }

        /**
         * Adds the strings reachable from {@code value}.
         *
         * <p>A String is added as-is; maps and lists are traversed recursively. Any other value
         * (numbers, booleans, null, ...) carries no text and is skipped.
         */
        private void traverseObject(Object value) {
            if (value instanceof String) {
                this.texts.add((String) value);
            } else if (value instanceof Map) {
                this.traverseMap((Map<?, ?>) value);
            } else if (value instanceof List) {
                for (Object item : (List<?>) value) {
                    this.traverseObject(item);
                }
            }
        }

        /**
         * Adds every key of {@code map}, then the strings reachable from its values.
         *
         * <p>Keys are typed as {@code Object} because maps reached via {@link #traverseObject}
         * (e.g. a schema example) are not guaranteed to have String keys. Non-String keys are
         * stringified, as JSON serialization would do, instead of polluting the
         * {@code List<String>}. Null keys carry no text and are skipped.
         */
        private void traverseMap(Map<?, ?> map) {
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                Object key = entry.getKey();
                if (key != null) {
                    this.texts.add(String.valueOf(key));
                }
                this.traverseObject(entry.getValue());
            }
        }
    }
}

