/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors.tokenizer;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * WordPiece tokenizer.
 *
 * <p>Input text is lower-cased, stripped of control characters, split on whitespace and punctuation,
 * and each resulting word is broken into greedy longest-match-first vocabulary pieces (pieces that
 * continue a word carry a {@code "##"} prefix). Every tokenized sentence is wrapped in
 * {@code [CLS] ... [SEP]}.
 */
public class WordPieceTokenizer implements Tokenizer {
    /** Words longer than this many UTF-16 chars are mapped directly to the unknown token. */
    private static final int MAX_INPUT_CHARS_PER_WORD = 200;
    /** Prefix used by the vocabulary for pieces that continue (rather than start) a word. */
    private static final String CONTINUATION_PREFIX = "##";
    /** Compiled once; equivalent to {@code String.split("\\s+")}. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    protected final TokenizerModel model;
    protected final PromptSupport promptSupport;
    /** Vocabulary id of {@link #sepString}. */
    protected final long sepToken;
    /** Vocabulary id of {@link #clsString}. */
    protected final long clsToken;
    /** Vocabulary id of {@link #unkString}. */
    protected final long unkToken;
    protected static final String sepString = "[SEP]";
    protected static final String clsString = "[CLS]";
    protected static final String unkString = "[UNK]";

    /**
     * Loads the tokenizer definition from {@code tokenizer.json} under {@code modelRoot}.
     *
     * @param modelRoot directory that must contain {@code tokenizer.json}
     * @throws IllegalArgumentException if the file is missing, the model type is neither absent nor
     *     {@code WordPiece}, or the vocabulary lacks {@code [SEP]}, {@code [CLS]} or {@code [UNK]}
     * @throws UncheckedIOException if loading the tokenizer fails with an {@link IOException}
     */
    public WordPieceTokenizer(Path modelRoot) {
        Preconditions.checkArgument(
                modelRoot.resolve("tokenizer.json").toFile().exists(),
                "No tokenizer.json found in " + modelRoot);
        try {
            this.model = SafeTensorSupport.loadTokenizer(modelRoot);
            Preconditions.checkArgument(
                    this.model.type == null || this.model.type.equalsIgnoreCase("WordPiece"),
                    "Invalid model type: " + this.model.type);
            this.promptSupport = new PromptSupport(this.model);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load tokenizer from " + modelRoot, e);
        }
        this.sepToken = requireTokenId(this.model, sepString);
        this.clsToken = requireTokenId(this.model, clsString);
        this.unkToken = requireTokenId(this.model, unkString);
    }

    /** Looks up a mandatory special token, failing with a descriptive message instead of an NPE. */
    private static long requireTokenId(TokenizerModel model, String token) {
        Long id = (Long) model.vocabLookup.get(token);
        Preconditions.checkArgument(id != null, "Tokenizer vocabulary has no %s token", token);
        return id;
    }

    @Override
    public TokenizerModel getModel() {
        return this.model;
    }

    /**
     * Splits {@code sentence} into vocabulary pieces, wrapped in {@code [CLS]} and {@code [SEP]}.
     *
     * @param sentence raw input text
     * @return a mutable list of vocabulary strings
     */
    @Override
    public List<String> tokenize(String sentence) {
        String cleaned = this.preProcess(sentence);
        List<String> pieces = Arrays.stream(WHITESPACE.split(cleaned))
                .flatMap(this::splitByPunctuation)
                .flatMap(this::wordPiece)
                .collect(Collectors.toList());
        List<String> tokens = new ArrayList<>(pieces.size() + 2);
        tokens.add(clsString);
        tokens.addAll(pieces);
        tokens.add(sepString);
        return tokens;
    }

    /**
     * Breaks one pre-split word into greedy longest-match-first vocabulary pieces.
     *
     * <p>If the word is too long, or any position cannot be matched, the whole word becomes a single
     * unknown token. Partial pieces are discarded, matching the reference WordPiece algorithm.
     */
    private Stream<String> wordPiece(String word) {
        String unk = this.model.unkToken;
        if (word.length() > MAX_INPUT_CHARS_PER_WORD) {
            return Stream.of(unk);
        }
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            String match = null;
            int end = word.length();
            // Try the longest remaining substring first, shrinking until a vocabulary hit.
            for (; end > start; end--) {
                String candidate = word.substring(start, end);
                if (start > 0) {
                    candidate = CONTINUATION_PREFIX + candidate;
                }
                if (this.model.vocabLookup.containsKey(candidate)) {
                    match = candidate;
                    break;
                }
            }
            if (match == null) {
                return Stream.of(unk);
            }
            pieces.add(match);
            start = end;
        }
        return pieces.stream();
    }

    /**
     * Lower-cases (locale-independently), strips and cleans {@code sentence}.
     *
     * <p>{@link Locale#ROOT} avoids locale-specific mappings such as Turkish dotless i, which would
     * otherwise change vocabulary lookups depending on the JVM's default locale.
     */
    protected String preProcess(String sentence) {
        return this.cleanText(sentence.toLowerCase(Locale.ROOT).strip());
    }

    /** Returns true for ISO control characters, except tab, newline and carriage return. */
    static boolean isControl(Integer c) {
        if (c == '\t' || c == '\n' || c == '\r') {
            return false;
        }
        return Character.isISOControl(c);
    }

    /**
     * Returns true if {@code cp} should be split off as its own token.
     *
     * <p>Every non-alphanumeric printable ASCII character counts (including symbols such as
     * {@code $}, {@code ^} and {@code `}), as does every Unicode punctuation category, including
     * initial/final quotation marks.
     */
    static boolean isPunctuation(Integer cp) {
        int c = cp;
        if (c >= '!' && c <= '/' || c >= ':' && c <= '@' || c >= '[' && c <= '`' || c >= '{' && c <= '~') {
            return true;
        }
        switch (Character.getType(c)) {
            case Character.CONNECTOR_PUNCTUATION:
            case Character.DASH_PUNCTUATION:
            case Character.START_PUNCTUATION:
            case Character.END_PUNCTUATION:
            case Character.INITIAL_QUOTE_PUNCTUATION:
            case Character.FINAL_QUOTE_PUNCTUATION:
            case Character.OTHER_PUNCTUATION:
                return true;
            default:
                return false;
        }
    }

    /**
     * Drops NUL, U+FFFD (replacement character) and control characters, and maps every whitespace or
     * space-separator character (including no-break spaces) to a plain ASCII space.
     */
    String cleanText(String sentence) {
        return sentence.codePoints()
                .filter(c -> c != 0 && c != 0xFFFD && !WordPieceTokenizer.isControl(c))
                .map(c -> Character.isWhitespace(c) || Character.isSpaceChar(c) ? ' ' : c)
                .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
                .toString();
    }

    /**
     * Splits {@code str} so that each punctuation code point becomes its own element. Runs of
     * non-punctuation text are kept together. No empty strings are produced.
     */
    Stream<String> splitByPunctuation(String str) {
        List<String> result = new ArrayList<>();
        int start = 0;
        int offset = 0;
        while (offset < str.length()) {
            int codepoint = str.codePointAt(offset);
            int width = Character.charCount(codepoint);
            if (WordPieceTokenizer.isPunctuation(codepoint)) {
                if (offset != start) {
                    result.add(str.substring(start, offset));
                }
                result.add(str.substring(offset, offset + width));
                start = offset + width;
            }
            offset += width;
        }
        if (start != str.length()) {
            result.add(str.substring(start));
        }
        return result.stream();
    }

    @Override
    public long[] encode(String sentence) {
        return this.tokenize(sentence).stream()
                .mapToLong(token -> (Long) this.model.vocabLookup.get(token))
                .toArray();
    }

    /** Strips the continuation prefix from word-internal pieces; prefixes a space otherwise. */
    protected String postProcessToken(String decoded) {
        if (decoded.startsWith(CONTINUATION_PREFIX)) {
            return decoded.substring(CONTINUATION_PREFIX.length());
        }
        return " " + decoded;
    }

    /**
     * Decodes a single token id.
     *
     * @throws IllegalArgumentException if {@code id} is not in the vocabulary
     */
    @Override
    public String decode(long id) {
        String piece = (String) this.model.vocabLookup.inverse().get(id);
        Preconditions.checkArgument(piece != null, "Unknown token id: %s", id);
        return this.postProcessToken(piece);
    }

    protected String postProcess(String sentence) {
        return sentence.strip();
    }

    @Override
    public String decode(long[] ids) {
        return this.postProcess(Arrays.stream(ids).mapToObj(this::decode).collect(Collectors.joining()));
    }

    @Override
    public Optional<PromptSupport> promptSupport() {
        return this.model.promptTemplates().isPresent() ? Optional.of(this.promptSupport) : Optional.empty();
    }
}

