/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors.tokenizer;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for byte-pair-encoding (BPE) tokenizers loaded from a {@code tokenizer.json} file.
 *
 * <p>Subclasses supply the byte-level encoding used for characters missing from the vocabulary
 * ({@link #encodeCharacterAsToken(byte)}) and the reverse mapping for tokens that stand for a single
 * byte/character ({@link #maybeDecodeTokenAsCharacter(long)}).
 *
 * <p>Not thread-safe: {@link #decode(long)} accumulates the bytes of a multi-byte character in the
 * shared {@link #decodeBuffer} across successive calls.
 */
public abstract class BPETokenizer implements Tokenizer {
    protected static final Logger logger = LoggerFactory.getLogger(BPETokenizer.class);

    /** Maximum number of bytes in one UTF-8 encoded code point. */
    private static final int MAX_UTF8_BYTES = 4;

    protected final TokenizerModel model;
    protected final PromptSupport promptSupport;

    /** Pending bytes of a partially decoded multi-byte UTF-8 character; see {@link #decode(long)}. */
    protected final ByteBuffer decodeBuffer = ByteBuffer.allocate(MAX_UTF8_BYTES);

    /**
     * Maps every byte value outside the ranges 33-126, 161-172 and 174-255 to a stand-in code point,
     * assigned consecutively from 256 in increasing byte order. Immutable once initialized.
     * NOTE(review): this appears to mirror GPT-2's bytes-to-unicode table; confirm against subclasses.
     */
    public static BiMap<Integer, Integer> alteredBytes;

    static {
        BiMap<Integer, Integer> remapped = HashBiMap.create();
        int nextOffset = 0;
        for (int b = 0; b < 256; ++b) {
            boolean keptAsIs = (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255);
            if (!keptAsIs) {
                remapped.put(b, 256 + nextOffset++);
            }
        }
        alteredBytes = ImmutableBiMap.copyOf(remapped);
    }

    /**
     * Loads the tokenizer model found in {@code modelRoot}.
     *
     * @param modelRoot directory that must contain {@code tokenizer.json}
     * @throws IllegalArgumentException if {@code tokenizer.json} does not exist in {@code modelRoot}
     * @throws RuntimeException wrapping the {@link IOException} raised while loading the model
     */
    protected BPETokenizer(Path modelRoot) {
        Preconditions.checkArgument(
                modelRoot.resolve("tokenizer.json").toFile().exists(),
                "No tokenizer.json found in " + modelRoot);
        try {
            this.model = SafeTensorSupport.loadTokenizer(modelRoot);
            this.promptSupport = new PromptSupport(this.model);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public TokenizerModel getModel() {
        return this.model;
    }

    /**
     * Splits {@code sentence} into pieces that are encoded independently.
     *
     * <p>When the model has an added-token pattern, the sentence is split on it first. Pieces that
     * are added tokens are kept whole; the remaining pieces are run through the pre-tokenizer, if
     * one is configured. Empty pieces are dropped.
     *
     * @param sentence text to split
     * @return the pieces in order; empty for an empty sentence, and the sentence itself when the
     *     model has neither a pre-tokenizer nor an added-token pattern
     */
    @Override
    public List<String> tokenize(String sentence) {
        if (sentence.isEmpty()) {
            return Collections.emptyList();
        }
        if (this.model.preTokenizer() == null && this.model.addedTokenPattern() == null) {
            // Nothing to split on: the whole sentence is a single piece.
            return Collections.singletonList(sentence);
        }

        List<String> pieces = new ArrayList<>();
        if (this.model.addedTokenPattern() == null) {
            pieces.addAll(this.model.preTokenizer().pretokenize(sentence));
            return pieces;
        }

        for (String piece : TokenizerModel.split(this.model.addedTokenPattern(), sentence, 0, true)) {
            if (piece.isEmpty()) {
                continue;
            }
            boolean isAddedToken =
                    this.model.addedTokens() != null && this.model.addedTokens().containsKey(piece);
            if (isAddedToken || this.model.preTokenizer() == null) {
                pieces.add(piece);
            } else {
                pieces.addAll(this.model.preTokenizer().pretokenize(piece));
            }
        }
        return pieces;
    }

    /** Hook applied to each non-added-token piece before it is split into characters; identity by default. */
    protected String preProcess(String sentence) {
        return sentence;
    }

    /**
     * Encodes {@code rawSentence} into token ids.
     *
     * <p>Each piece from {@link #tokenize(String)} is either an added token (mapped directly) or is
     * pre-processed, split into per-character tokens, and then greedily merged using the model's
     * merge ranks.
     *
     * @param rawSentence text to encode
     * @return the token ids in order
     */
    @Override
    public long[] encode(String rawSentence) {
        List<Long> allTokens = new ArrayList<>();
        for (String piece : this.tokenize(rawSentence)) {
            if (this.model.addedTokens() != null && this.model.addedTokens().containsKey(piece)) {
                allTokens.add(this.model.addedTokens().get(piece));
                continue;
            }
            List<Long> tokens = this.initialTokens(this.preProcess(piece));
            this.applyMerges(tokens);
            allTokens.addAll(tokens);
        }
        return allTokens.stream().mapToLong(Long::longValue).toArray();
    }

    /**
     * Maps each code point of {@code text} to a vocabulary id.
     *
     * <p>A code point missing from the vocabulary becomes one token per UTF-8 byte when byte
     * fallback is enabled. Otherwise it becomes the unknown token, or is dropped when the model
     * has no unknown token.
     */
    private List<Long> initialTokens(String text) {
        List<Long> tokens = new ArrayList<>();
        for (int codePoint : text.codePoints().toArray()) {
            String ch = Character.toString(codePoint);
            Long id = (Long) this.model.vocabLookup.get(ch);
            if (id != null) {
                tokens.add(id);
            } else if (this.model.byteFallback) {
                for (byte b : ch.getBytes(StandardCharsets.UTF_8)) {
                    tokens.add(this.encodeCharacterAsToken(b));
                }
            } else if (this.model.unkToken != null) {
                // NOTE(review): adds null if unkToken is not itself in the vocabulary, which later fails on unboxing.
                tokens.add((Long) this.model.vocabLookup.get(this.model.unkToken));
            }
        }
        return tokens;
    }

    /**
     * Repeatedly merges the adjacent pair with the lowest merge rank, in place.
     *
     * <p>A pair is eligible only if "left right" is a merge rule and "leftright" is in the
     * vocabulary. Merging stops when no pair is eligible.
     */
    private void applyMerges(List<Long> tokens) {
        while (tokens.size() > 1) {
            int bestIdx = -1;
            long bestId = -1L;
            long bestRank = Long.MAX_VALUE;
            // Decode each token once per pass; the right side of pair i is the left side of pair i + 1.
            String left = this.decodeInternal(tokens.get(0));
            for (int i = 0; i + 1 < tokens.size(); ++i) {
                String right = this.decodeInternal(tokens.get(i + 1));
                String mergeKey = left + " " + right;
                if (this.model.merges.containsKey(mergeKey)) {
                    Long mergedId = (Long) this.model.vocabLookup.get(left + right);
                    if (mergedId != null) {
                        long rank = this.model.merges.get(mergeKey).longValue();
                        if (rank < bestRank) {
                            bestIdx = i;
                            bestId = mergedId;
                            bestRank = rank;
                        }
                    }
                }
                left = right;
            }
            if (bestIdx < 0) {
                return;
            }
            tokens.set(bestIdx, bestId);
            tokens.remove(bestIdx + 1);
        }
    }

    /** Hook applied to a decoded vocabulary entry; substitutes the unknown token for {@code null}. */
    protected String postProcessToken(String decoded) {
        if (decoded == null) {
            decoded = this.model.unkToken;
        }
        return decoded;
    }

    /**
     * Decodes a single token id.
     *
     * <p>Ids that {@link #maybeDecodeTokenAsCharacter(long)} maps to a character are treated as raw
     * bytes. Identifier-like bytes, and any byte that arrives while a sequence is pending, are
     * buffered in {@link #decodeBuffer}. Once the buffer holds the complete UTF-8 sequence started
     * by its first byte, that sequence is returned; until then the result is {@code ""}. All other
     * ids are looked up in the vocabulary and passed through {@link #postProcessToken(String)}.
     *
     * @param id token id
     * @return decoded text, possibly empty while a multi-byte character is incomplete
     */
    @Override
    public String decode(long id) {
        return this.maybeDecodeTokenAsCharacter(id)
                .map(this::decodeByteToken)
                .orElseGet(() -> this.postProcessToken((String) this.model.vocabLookup.inverse().get(id)));
    }

    /** Buffers or emits one byte-valued token; see {@link #decode(long)}. */
    private String decodeByteToken(Character c) {
        char ch = c;
        if (!Character.isUnicodeIdentifierPart(ch) && this.decodeBuffer.position() == 0) {
            return Character.toString(ch);
        }
        this.decodeBuffer.put((byte) ch);
        // The first buffered byte determines how long the UTF-8 sequence is.
        int expected = utf8SequenceLength(this.decodeBuffer.get(0));
        if (this.decodeBuffer.position() < expected && this.decodeBuffer.hasRemaining()) {
            return "";
        }
        String s = new String(this.decodeBuffer.array(), 0, this.decodeBuffer.position(), StandardCharsets.UTF_8);
        this.decodeBuffer.clear();
        return s;
    }

    /** Returns the length of the UTF-8 sequence introduced by {@code lead}, or 1 if it is not a multi-byte lead byte. */
    private static int utf8SequenceLength(byte lead) {
        int b = lead & 0xFF;
        if (b >= 0xC0 && b < 0xE0) {
            return 2;
        }
        if (b >= 0xE0 && b < 0xF0) {
            return 3;
        }
        if (b >= 0xF0 && b < 0xF8) {
            return 4;
        }
        return 1;
    }

    /** Returns the vocabulary id representing byte {@code var1}; used for byte fallback during encoding. */
    protected abstract long encodeCharacterAsToken(byte var1);

    /** Returns the single character a token stands for, or empty if it is an ordinary vocabulary token. */
    protected abstract Optional<Character> maybeDecodeTokenAsCharacter(long var1);

    /**
     * Returns the text of a token for merge lookups, without the byte buffering done by
     * {@link #decode(long)}: the token's character if it has one, else its vocabulary entry, else
     * the unknown token.
     */
    protected String decodeInternal(long id) {
        return this.maybeDecodeTokenAsCharacter(id).map(Object::toString).orElseGet(() -> {
            String s = (String) this.model.vocabLookup.inverse().get(id);
            return s != null ? s : this.model.unkToken;
        });
    }

    /** Hook applied to the fully joined output of {@link #decode(long[])}; identity by default. */
    protected String postProcess(String sentence) {
        return sentence;
    }

    /**
     * Decodes a sequence of token ids by concatenating {@link #decode(long)} for each id, then
     * applying {@link #postProcess(String)}.
     */
    @Override
    public String decode(long[] ids) {
        return this.postProcess(Arrays.stream(ids).mapToObj(this::decode).collect(Collectors.joining()));
    }

    /** Returns the prompt support only when the model defines prompt templates. */
    @Override
    public Optional<PromptSupport> promptSupport() {
        return this.model.promptTemplates().isPresent() ? Optional.of(this.promptSupport) : Optional.empty();
    }
}

