package com.didalgo.gpt3;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Byte-pair-encoding (BPE) tokenizer driven by an {@link Encoding}: its mergeable ranks, its
 * special tokens and its pre-tokenization pattern.
 *
 * <p>Encoding first splits the input at allowed special tokens. Each ordinary segment in between
 * is split into pieces by the encoding's pattern. Each piece is then either looked up directly or
 * broken into tokens with {@link #bytePairMerge}.
 */
public class GPT3Tokenizer {
    /** Byte sequence to token id; the id doubles as the merge rank (lower ranks merge first). */
    private final Map<ByteSequence, Integer> encoder;
    /** Inverse of {@link #encoder}. */
    private final Map<Integer, ByteSequence> decoder;
    /** Special-token text to token id. */
    private final Map<String, Integer> specialTokensEncoder;
    /** Inverse of {@link #specialTokensEncoder}. */
    private final Map<Integer, String> specialTokensDecoder;
    /** Pattern that splits ordinary (non-special) text into pieces. */
    private final Pattern pattern;
    /** Pattern matching any special token. */
    private final Pattern specialPattern;

    /**
     * Mutable working cell used by {@link #bytePairMerge}.
     *
     * <p>{@code start} is a byte offset into the piece being merged. While merging, {@code end}
     * holds the rank of merging this part with the next one, or {@link Integer#MAX_VALUE} when
     * that pair is not mergeable.
     */
    public static class IntPair {
        int start;
        int end;

        IntPair(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /**
     * Creates a tokenizer for the given encoding.
     *
     * @param encoding supplies the mergeable ranks, the special tokens and the split pattern
     * @throws IllegalStateException if two entries of the rank map (or of the special-token map)
     *     share the same token id, because the inverse lookup maps could not be built
     */
    public GPT3Tokenizer(Encoding encoding) {
        this.encoder = encoding.mergeableRanks();
        this.decoder = invert(this.encoder);
        this.specialTokensEncoder = encoding.specialTokens();
        this.specialTokensDecoder = invert(this.specialTokensEncoder);
        this.pattern = encoding.pattern();
        // NOTE: this calls an overridable method from the constructor. It is kept this way for
        // compatibility with subclasses that customize the special-token regex.
        this.specialPattern = createSpecialRegex(this.specialTokensEncoder);
    }

    /** Builds the value-to-key inverse of {@code map}; throws on duplicate values. */
    private static <K, V> Map<V, K> invert(Map<K, V> map) {
        return map.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
    }

    /**
     * Builds a regex that matches any of the given special tokens literally.
     *
     * @param map special tokens keyed by their text (values are ignored)
     * @return an alternation of the quoted keys, or a pattern that never matches when
     *     {@code map} is empty
     */
    protected Pattern createSpecialRegex(Map<String, ?> map) {
        if (map.isEmpty()) {
            // An empty alternation would compile to "", which matches the empty string at every
            // position and would make encoding hang; use a pattern that can never match instead.
            return Pattern.compile("(?!)");
        }
        return Pattern.compile(map.keySet().stream()
                .map(Pattern::quote)
                .collect(Collectors.joining("|")));
    }

    /**
     * Decodes token ids back into text.
     *
     * @param list token ids, which may include special-token ids
     * @return the UTF-8 decoding of the concatenated token bytes
     * @throws IllegalArgumentException if an id is neither a regular nor a special token
     */
    public String decode(List<Integer> list) {
        return decodeImpl(list);
    }

    /**
     * Implementation of {@link #decode}. It concatenates the bytes of every token and decodes the
     * result as UTF-8.
     */
    protected String decodeImpl(List<Integer> list) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (Integer token : list) {
            ByteSequence bytes = this.decoder.get(token);
            if (bytes != null) {
                out.writeBytes(bytes.toByteArray());
                continue;
            }
            String special = this.specialTokensDecoder.get(token);
            if (special == null) {
                throw new IllegalArgumentException("Unknown token id: " + token);
            }
            // The buffer is decoded as UTF-8 below, so special tokens must be encoded the same way.
            out.writeBytes(special.getBytes(StandardCharsets.UTF_8));
        }
        return out.toString(StandardCharsets.UTF_8);
    }

    /** Returns the special-token pattern used by {@link #encodeImpl}; subclasses may override. */
    protected Pattern getTlSpecialRegex() {
        return this.specialPattern;
    }

    /** Returns the piece-splitting pattern used by {@link #encodeImpl}; subclasses may override. */
    protected Pattern getTlRegex() {
        return this.pattern;
    }

    /**
     * Encodes text, treating every special token in it as ordinary text.
     *
     * @param str text to encode
     * @return token ids
     */
    public List<Integer> encode(String str) {
        return encode(str, false);
    }

    /**
     * Encodes text.
     *
     * @param str text to encode
     * @param allowAllSpecial if {@code true}, every known special token in {@code str} is emitted
     *     as its special id; otherwise special tokens are encoded as ordinary text
     * @return token ids
     */
    public List<Integer> encode(String str, boolean allowAllSpecial) {
        return encode(str, allowAllSpecial ? this.specialTokensEncoder.keySet() : Set.of());
    }

    /**
     * Encodes text.
     *
     * @param str text to encode
     * @param set special tokens that are emitted as special ids when they occur in {@code str};
     *     any other special token is encoded as ordinary text
     * @return token ids
     */
    public List<Integer> encode(String str, Set<String> set) {
        return encodeImpl(str, set);
    }

    /**
     * Implementation of {@link #encode(String, Set)}.
     *
     * <p>It repeatedly locates the next allowed special token and encodes the ordinary text before
     * it piece by piece. It then emits the special token's id and continues after it. All positions
     * are absolute indices into {@code str}.
     */
    protected List<Integer> encodeImpl(String str, Set<String> set) {
        Pattern specialRegex = getTlSpecialRegex();
        Pattern pieceRegex = getTlRegex();
        List<Integer> tokens = new ArrayList<>(str.length() / 4);
        Matcher special = specialRegex.matcher(str);

        int start = 0;
        while (true) {
            // Find the next *allowed* special token at or after `start`. A disallowed one is
            // skipped by resuming the search one character past its start, so it stays in the
            // ordinary text. Empty matches are never accepted, which guarantees progress.
            boolean foundSpecial = false;
            int searchFrom = start;
            while (searchFrom <= str.length() && special.find(searchFrom)) {
                if (special.end() > special.start() && set.contains(special.group())) {
                    foundSpecial = true;
                    break;
                }
                searchFrom = special.start() + 1;
            }

            // Encode the ordinary text up to the special token, or up to the end of the input.
            int end = foundSpecial ? special.start() : str.length();
            Matcher pieces = pieceRegex.matcher(str.substring(start, end));
            while (pieces.find()) {
                ByteSequence piece = ByteSequence.from(pieces.group());
                Integer token = this.encoder.get(piece);
                if (token != null) {
                    tokens.add(token);
                } else {
                    bytePairMerge(piece, tokens);
                }
            }

            if (!foundSpecial) {
                return tokens;
            }
            tokens.add(this.specialTokensEncoder.get(special.group()));
            start = special.end();
        }
    }

    /**
     * Returns the rank of the byte range that would result from merging part {@code i} with part
     * {@code i + 1}.
     *
     * @param byteSequence the piece being merged
     * @param list current part boundaries (the last entry marks the end of the piece)
     * @param i index of the left part
     * @return the rank from the encoder, or {@link Integer#MAX_VALUE} if there is no right
     *     neighbour or the merged bytes are not a known token
     */
    protected int getRank(ByteSequence byteSequence, List<IntPair> list, int i) {
        if (i + 2 >= list.size()) {
            return Integer.MAX_VALUE;
        }
        Integer rank = this.encoder.get(
                byteSequence.subSequence(list.get(i).start, list.get(i + 2).start));
        return rank != null ? rank : Integer.MAX_VALUE;
    }

    /**
     * Splits a piece into tokens by repeatedly merging the adjacent pair with the lowest rank
     * until no adjacent pair is mergeable. The resulting token ids are appended to
     * {@code collection}.
     *
     * @param byteSequence the piece to tokenize
     * @param collection receives the token ids, in order
     * @return the number of token ids appended
     * @throws IllegalStateException if a resulting part is not present in the encoder
     */
    protected int bytePairMerge(ByteSequence byteSequence, Collection<Integer> collection) {
        // One boundary per byte offset, plus a trailing sentinel at length(). Initially every
        // part is a single byte.
        List<IntPair> parts = new ArrayList<>(byteSequence.length() + 1);
        for (int offset = 0; offset <= byteSequence.length(); offset++) {
            parts.add(new IntPair(offset, Integer.MAX_VALUE));
        }
        for (int i = 0; i < parts.size() - 2; i++) {
            parts.get(i).end = getRank(byteSequence, parts, i);
        }

        while (parts.size() > 1) {
            int minRank = Integer.MAX_VALUE;
            int minIndex = -1;
            for (int i = 0; i < parts.size() - 1; i++) {
                int rank = parts.get(i).end;
                if (rank < minRank) {
                    minRank = rank;
                    minIndex = i;
                }
            }
            if (minIndex < 0) {
                break; // no mergeable pair left
            }
            // Merge parts minIndex and minIndex+1, then refresh the only two ranks that changed.
            parts.remove(minIndex + 1);
            parts.get(minIndex).end = getRank(byteSequence, parts, minIndex);
            if (minIndex > 0) {
                parts.get(minIndex - 1).end = getRank(byteSequence, parts, minIndex - 1);
            }
        }

        int emitted = 0;
        for (int i = 0; i < parts.size() - 1; i++) {
            int from = parts.get(i).start;
            int to = parts.get(i + 1).start;
            Integer token = this.encoder.get(byteSequence.subSequence(from, to));
            if (token == null) {
                throw new IllegalStateException(
                        "No token for byte range [" + from + ", " + to + ") of piece");
            }
            collection.add(token);
            emitted++;
        }
        return emitted;
    }
}
