package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.huggingface.tokenizers.jni.LibUtils;
import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.util.Ec2Utils;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import ai.djl.util.Platform;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.class */
public final class HuggingFaceTokenizer extends NativeResource<Long> implements Tokenizer {
    /** Class logger; used to warn when {@code maxLength} is adjusted. */
    private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);
    /** Default "add special tokens" flag used by the convenience encode overloads. */
    private boolean addSpecialTokens;
    /** Default flag deciding whether overflowing encodings are collected into the result. */
    private boolean withOverflowingTokens;
    /** Locale used to lower-case input text before encoding; {@code null} disables lower-casing. */
    private Locale doLowerCase;
    /** Truncation strategy handed to the native tokenizer. */
    private TruncationStrategy truncation;
    /** Padding strategy handed to the native tokenizer. */
    private PaddingStrategy padding;
    /** Maximum sequence length; {@code -1} is treated as "not specified" and replaced by modelMaxLength. */
    private int maxLength;
    /** Stride value forwarded to the native truncation configuration. */
    private int stride;
    /** When non-zero, maxLength may be adjusted to a multiple of this value. */
    private int padToMultipleOf;
    /** Upper bound applied to maxLength; defaults to 512 when not configured. */
    private int modelMaxLength;

    /* loaded from: input_file:ai/djl/huggingface/tokenizers/HuggingFaceTokenizer$Builder.class */
    /**
     * Fluent builder for {@link HuggingFaceTokenizer}.
     *
     * <p>Options are collected as string key/value pairs and handed to the tokenizer factory
     * methods by {@link #build()}. The {@code addSpecialTokens} option starts out as
     * {@code "true"}.
     */
    public static final class Builder {
        /** Optional manager the built tokenizer is attached to; {@code null} means unmanaged. */
        private NDManager manager;

        /** Collected options; keys are the option names read by the tokenizer constructor. */
        private final Map<String, String> options = new ConcurrentHashMap<>();

        /** Creates a builder with {@code addSpecialTokens} enabled. */
        Builder() {
            options.put("addSpecialTokens", "true");
        }

        /**
         * Sets the manager that the built tokenizer will be attached to.
         *
         * @param nDManager the manager, or {@code null} to leave the tokenizer unmanaged
         * @return this builder
         */
        public Builder optManager(NDManager nDManager) {
            manager = nDManager;
            return this;
        }

        /**
         * Sets the tokenizer name; when present it takes precedence over the tokenizer path.
         *
         * @param str the tokenizer name
         * @return this builder
         */
        public Builder optTokenizerName(String str) {
            options.put("tokenizer", str);
            return this;
        }

        /**
         * Sets the tokenizer path unless a {@code tokenizerPath} option is already present.
         *
         * @param path a tokenizer file or a directory containing the tokenizer files
         * @return this builder
         */
        public Builder optTokenizerPath(Path path) {
            // putIfAbsent: an already configured tokenizerPath is deliberately left untouched.
            options.putIfAbsent("tokenizerPath", path.toString());
            return this;
        }

        /**
         * Sets the default "add special tokens" flag.
         *
         * @param z the flag value
         * @return this builder
         */
        public Builder optAddSpecialTokens(boolean z) {
            options.put("addSpecialTokens", String.valueOf(z));
            return this;
        }

        /**
         * Sets the default "with overflowing tokens" flag.
         *
         * @param z the flag value
         * @return this builder
         */
        public Builder optWithOverflowingTokens(boolean z) {
            options.put("withOverflowingTokens", String.valueOf(z));
            return this;
        }

        /**
         * Enables ({@code LONGEST_FIRST}) or disables ({@code DO_NOT_TRUNCATE}) truncation.
         *
         * @param z {@code true} to truncate
         * @return this builder
         */
        public Builder optTruncation(boolean z) {
            options.put("truncation", String.valueOf(z));
            return this;
        }

        /**
         * Selects the {@code ONLY_FIRST} truncation strategy.
         *
         * @return this builder
         */
        public Builder optTruncateFirstOnly() {
            options.put("truncation", TruncationStrategy.ONLY_FIRST.name());
            return this;
        }

        /**
         * Selects the {@code ONLY_SECOND} truncation strategy.
         *
         * @return this builder
         */
        public Builder optTruncateSecondOnly() {
            options.put("truncation", TruncationStrategy.ONLY_SECOND.name());
            return this;
        }

        /**
         * Enables ({@code LONGEST}) or disables ({@code DO_NOT_PAD}) padding.
         *
         * @param z {@code true} to pad
         * @return this builder
         */
        public Builder optPadding(boolean z) {
            options.put("padding", String.valueOf(z));
            return this;
        }

        /**
         * Selects the {@code MAX_LENGTH} padding strategy.
         *
         * @return this builder
         */
        public Builder optPadToMaxLength() {
            options.put("padding", PaddingStrategy.MAX_LENGTH.name());
            return this;
        }

        /**
         * Sets the {@code maxLength} option.
         *
         * @param i the maximum length
         * @return this builder
         */
        public Builder optMaxLength(int i) {
            options.put("maxLength", String.valueOf(i));
            return this;
        }

        /**
         * Sets the {@code padToMultipleOf} option.
         *
         * @param i the multiple to pad to
         * @return this builder
         */
        public Builder optPadToMultipleOf(int i) {
            options.put("padToMultipleOf", String.valueOf(i));
            return this;
        }

        /**
         * Sets the {@code stride} option.
         *
         * @param i the stride
         * @return this builder
         */
        public Builder optStride(int i) {
            options.put("stride", String.valueOf(i));
            return this;
        }

        /**
         * Enables or disables lower-casing of input text; when enabled the tokenizer uses
         * {@link java.util.Locale#getDefault()}.
         *
         * @param z {@code true} to lower-case input
         * @return this builder
         */
        public Builder optDoLowerCase(boolean z) {
            options.put("doLowerCase", String.valueOf(z));
            return this;
        }

        /**
         * Sets the {@code doLowerCase} option from a string: {@code "true"}, {@code "false"}, or a
         * language tag that the tokenizer resolves with {@code Locale.forLanguageTag}.
         *
         * @param str the option value
         * @return this builder
         */
        public Builder optDoLowerCase(String str) {
            options.put("doLowerCase", str);
            return this;
        }

        /**
         * Copies every entry of {@code map} into the builder options, converting values with
         * {@code toString()}. Entries overwrite options that were set earlier.
         *
         * @param map the options to apply; values must not be {@code null}
         * @throws NullPointerException if {@code map} or any of its values is {@code null}
         */
        public void configure(Map<String, ?> map) {
            for (Map.Entry<String, ?> entry : map.entrySet()) {
                options.put(entry.getKey(), entry.getValue().toString());
            }
        }

        /** Attaches the tokenizer to the configured manager, if any, and returns it. */
        private HuggingFaceTokenizer managed(HuggingFaceTokenizer huggingFaceTokenizer) {
            if (manager != null) {
                manager.attachInternal(
                        huggingFaceTokenizer.getUid(), new AutoCloseable[] {huggingFaceTokenizer});
            }
            return huggingFaceTokenizer;
        }

        /**
         * Builds the tokenizer.
         *
         * <p>Resolution order: the {@code tokenizer} name option; otherwise the
         * {@code tokenizerPath} option, which may be a tokenizer file, a directory holding
         * {@code tokenizer.json}, or a directory holding both {@code vocab.json} and
         * {@code merges.txt}.
         *
         * @return the new tokenizer, attached to the manager when one was set
         * @throws IllegalArgumentException if neither a name nor a path was configured
         * @throws IOException if the path does not exist, the directory lacks the expected
         *     files, or the tokenizer file cannot be read
         */
        public HuggingFaceTokenizer build() throws IOException {
            String tokenizerName = options.get("tokenizer");
            if (tokenizerName != null) {
                return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options));
            }
            String location = options.get("tokenizerPath");
            if (location == null) {
                throw new IllegalArgumentException("Missing tokenizer path.");
            }
            Path path = Paths.get(location);
            if (Files.isDirectory(path)) {
                if (Files.exists(path.resolve("tokenizer.json"))) {
                    return managed(HuggingFaceTokenizer.newInstance(path, options));
                }
                // Fall back to a vocabulary + merges pair when tokenizer.json is absent.
                Path vocab = path.resolve("vocab.json");
                Path merges = path.resolve("merges.txt");
                if (Files.exists(vocab) && Files.exists(merges)) {
                    return managed(HuggingFaceTokenizer.newInstance(vocab, merges, options));
                }
                throw new IOException("tokenizer.json file not found.");
            }
            if (!Files.exists(path)) {
                throw new IOException("Tokenizer file does not exist: " + path);
            }
            return managed(HuggingFaceTokenizer.newInstance(path, options));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/huggingface/tokenizers/HuggingFaceTokenizer$PaddingStrategy.class */
    /** Padding strategies that can be selected through the {@code padding} option. */
    public enum PaddingStrategy {
        LONGEST,
        MAX_LENGTH,
        DO_NOT_PAD;

        /**
         * Parses a padding option value.
         *
         * <p>{@code "true"} maps to {@link #LONGEST} and {@code "false"} to {@link #DO_NOT_PAD};
         * any other value must equal a constant name. All comparisons ignore case, so the boolean
         * aliases are treated the same way as the constant names.
         *
         * @param str the option value
         * @return the matching strategy
         * @throws IllegalArgumentException if {@code str} is {@code null} or matches nothing
         */
        static PaddingStrategy fromValue(String str) {
            if ("true".equalsIgnoreCase(str)) {
                return LONGEST;
            }
            if ("false".equalsIgnoreCase(str)) {
                return DO_NOT_PAD;
            }
            for (PaddingStrategy candidate : values()) {
                if (candidate.name().equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Invalid PaddingStrategy: " + str);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/huggingface/tokenizers/HuggingFaceTokenizer$TruncationStrategy.class */
    /** Truncation strategies that can be selected through the {@code truncation} option. */
    public enum TruncationStrategy {
        LONGEST_FIRST,
        ONLY_FIRST,
        ONLY_SECOND,
        DO_NOT_TRUNCATE;

        /**
         * Parses a truncation option value.
         *
         * <p>{@code "true"} maps to {@link #LONGEST_FIRST} and {@code "false"} to
         * {@link #DO_NOT_TRUNCATE}; any other value must equal a constant name. All comparisons
         * ignore case, so the boolean aliases are treated the same way as the constant names.
         *
         * @param str the option value
         * @return the matching strategy
         * @throws IllegalArgumentException if {@code str} is {@code null} or matches nothing
         */
        static TruncationStrategy fromValue(String str) {
            if ("true".equalsIgnoreCase(str)) {
                return LONGEST_FIRST;
            }
            if ("false".equalsIgnoreCase(str)) {
                return DO_NOT_TRUNCATE;
            }
            for (TruncationStrategy candidate : values()) {
                if (candidate.name().equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Invalid TruncationStrategy: " + str);
        }
    }

    /**
     * Wraps a native tokenizer handle and applies the given options.
     *
     * <p>Truncation defaults to {@code LONGEST_FIRST} and padding to {@code LONGEST}; maxLength,
     * stride and padToMultipleOf start from the values reported by the native tokenizer and are
     * overridden by the matching options when present. Finally the resulting truncation and
     * padding settings are pushed to the native tokenizer.
     *
     * @param j the native tokenizer handle
     * @param map the options, or {@code null} to use defaults only
     */
    private HuggingFaceTokenizer(long j, Map<String, String> map) {
        super(Long.valueOf(j));
        truncation = TruncationStrategy.LONGEST_FIRST;
        padding = PaddingStrategy.LONGEST;
        maxLength = TokenizersLibrary.LIB.getMaxLength(j);
        stride = TokenizersLibrary.LIB.getStride(j);
        padToMultipleOf = TokenizersLibrary.LIB.getPadToMultipleOf(j);
        if (map == null) {
            addSpecialTokens = true;
            modelMaxLength = 512;
        } else {
            addSpecialTokens =
                    Boolean.parseBoolean(map.getOrDefault("addSpecialTokens", "true"));
            withOverflowingTokens =
                    Boolean.parseBoolean(map.getOrDefault("withOverflowingTokens", "false"));
            modelMaxLength = ArgumentsUtil.intValue(map, "modelMaxLength", 512);
            if (map.containsKey("truncation")) {
                truncation = TruncationStrategy.fromValue(map.get("truncation"));
            }
            if (map.containsKey("padding")) {
                padding = PaddingStrategy.fromValue(map.get("padding"));
            }
            maxLength = ArgumentsUtil.intValue(map, "maxLength", maxLength);
            stride = ArgumentsUtil.intValue(map, "stride", stride);
            padToMultipleOf = ArgumentsUtil.intValue(map, "padToMultipleOf", padToMultipleOf);
            // "true" selects the default locale, "false" leaves lower-casing off, and anything
            // else is interpreted as a language tag.
            String lowerCaseOption = map.getOrDefault("doLowerCase", "false");
            if (!"false".equals(lowerCaseOption)) {
                doLowerCase =
                        "true".equals(lowerCaseOption)
                                ? Locale.getDefault()
                                : Locale.forLanguageTag(lowerCaseOption);
            }
        }
        updateTruncationAndPadding();
    }

    /**
     * Creates a tokenizer from a tokenizer name using default options.
     *
     * @param str the tokenizer name
     * @return the new tokenizer
     */
    public static HuggingFaceTokenizer newInstance(String str) {
        Map<String, String> noOptions = null;
        return newInstance(str, noOptions);
    }

    /**
     * Creates a tokenizer from a tokenizer name.
     *
     * <p>The access token is looked up under {@code HF_TOKEN} through
     * {@code Utils.getEnvOrSystemProperty}; an {@code hf_token} entry in {@code map} overrides
     * it.
     *
     * @param str the tokenizer name
     * @param map the tokenizer options, may be {@code null}
     * @return the new tokenizer
     */
    public static HuggingFaceTokenizer newInstance(String str, Map<String, String> map) {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();
        String fallbackToken = Utils.getEnvOrSystemProperty("HF_TOKEN");
        String token = map == null ? fallbackToken : map.getOrDefault("hf_token", fallbackToken);
        long nativeHandle = TokenizersLibrary.LIB.createTokenizer(str, token);
        return new HuggingFaceTokenizer(nativeHandle, map);
    }

    /**
     * Creates a tokenizer from a tokenizer file or directory using default options.
     *
     * @param path the tokenizer file, or a directory containing {@code tokenizer.json}
     * @return the new tokenizer
     * @throws IOException if the file cannot be read
     */
    public static HuggingFaceTokenizer newInstance(Path path) throws IOException {
        Map<String, String> noOptions = null;
        return newInstance(path, noOptions);
    }

    /**
     * Creates a tokenizer from a tokenizer file or directory.
     *
     * @param path the tokenizer file; a directory is resolved to its {@code tokenizer.json}
     * @param map the tokenizer options, may be {@code null}
     * @return the new tokenizer
     * @throws IOException if the file cannot be opened or read
     */
    public static HuggingFaceTokenizer newInstance(Path path, Map<String, String> map) throws IOException {
        Path tokenizerFile = Files.isDirectory(path) ? path.resolve("tokenizer.json") : path;
        // The stream is closed on both the normal and the exceptional path.
        try (InputStream stream = Files.newInputStream(tokenizerFile)) {
            return newInstance(stream, map);
        }
    }

    /**
     * Creates a BPE tokenizer from two files; both are passed to the native library as absolute
     * path strings.
     *
     * @param path the first file (the builder passes {@code vocab.json} here)
     * @param path2 the second file (the builder passes {@code merges.txt} here)
     * @param map the tokenizer options, may be {@code null}
     * @return the new tokenizer
     * @throws IOException declared by the signature
     */
    public static HuggingFaceTokenizer newInstance(Path path, Path path2, Map<String, String> map) throws IOException {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();
        String firstFile = path.toAbsolutePath().toString();
        String secondFile = path2.toAbsolutePath().toString();
        long nativeHandle = TokenizersLibrary.LIB.createBpeTokenizer(firstFile, secondFile);
        return new HuggingFaceTokenizer(nativeHandle, map);
    }

    /**
     * Creates a tokenizer from the text content of a stream. The stream is read fully but is not
     * closed here.
     *
     * @param inputStream the stream holding the tokenizer definition
     * @param map the tokenizer options, may be {@code null}
     * @return the new tokenizer
     * @throws IOException if the stream cannot be read
     */
    public static HuggingFaceTokenizer newInstance(InputStream inputStream, Map<String, String> map) throws IOException {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();
        String definition = Utils.toString(inputStream);
        long nativeHandle = TokenizersLibrary.LIB.createTokenizerFromString(definition);
        return new HuggingFaceTokenizer(nativeHandle, map);
    }

    /**
     * Returns the version reported by the detected {@code "tokenizers"} platform.
     *
     * @return the version string
     */
    public String getVersion() {
        Platform tokenizersPlatform = Platform.detectPlatform("tokenizers");
        return tokenizersPlatform.getVersion();
    }

    /**
     * Encodes the text with the default settings and returns its tokens.
     *
     * @param str the text to tokenize
     * @return a fixed-size list view over the encoding's token array
     */
    public List<String> tokenize(String str) {
        String[] tokens = encode(str).getTokens();
        return Arrays.asList(tokens);
    }

    /**
     * Joins tokens with single spaces, removes every {@code " ##"} occurrence so that such
     * pieces attach to the preceding token, and trims the result.
     *
     * @param list the tokens
     * @return the rebuilt sentence
     */
    public String buildSentence(List<String> list) {
        String joined = String.join(" ", list);
        String merged = joined.replace(" ##", "");
        return merged.trim();
    }

    /**
     * Releases the native tokenizer. The handle is swapped to {@code null} atomically, so only
     * the first call deletes the native object and later calls do nothing.
     */
    public void close() {
        Long pointer = (Long) handle.getAndSet(null);
        if (pointer == null) {
            return;
        }
        TokenizersLibrary.LIB.deleteTokenizer(pointer.longValue());
    }

    /**
     * Encodes a single text.
     *
     * @param str the text; lower-cased first when a lower-case locale is configured
     * @param z the "add special tokens" flag passed to the native encoder
     * @param z2 whether overflowing encodings are collected into the result
     * @return the encoding
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public Encoding encode(String str, boolean z, boolean z2) {
        Objects.requireNonNull(str, "text cannot be null");
        String text = doLowerCase == null ? str : str.toLowerCase(doLowerCase);
        long nativeHandle = (Long) getHandle();
        long encoding = TokenizersLibrary.LIB.encode(nativeHandle, text, z);
        return toEncoding(encoding, z2);
    }

    /**
     * Encodes a single text using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param str the text
     * @return the encoding
     */
    public Encoding encode(String str) {
        return encode(str, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes a text pair.
     *
     * @param str the first text
     * @param str2 the second text
     * @param z the "add special tokens" flag passed to the native encoder
     * @param z2 whether overflowing encodings are collected into the result
     * @return the encoding
     * @throws NullPointerException if either text is {@code null}
     */
    public Encoding encode(String str, String str2, boolean z, boolean z2) {
        if (str == null || str2 == null) {
            throw new NullPointerException("text/text_pair cannot be null");
        }
        String first = str;
        String second = str2;
        if (doLowerCase != null) {
            first = first.toLowerCase(doLowerCase);
            second = second.toLowerCase(doLowerCase);
        }
        long nativeHandle = (Long) getHandle();
        long encoding = TokenizersLibrary.LIB.encodeDual(nativeHandle, first, second, z);
        return toEncoding(encoding, z2);
    }

    /**
     * Encodes a text pair using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param str the first text
     * @param str2 the second text
     * @return the encoding
     */
    public Encoding encode(String str, String str2) {
        return encode(str, str2, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes a list of texts into a single encoding by delegating to the array overload.
     *
     * @param list the texts
     * @param z the "add special tokens" flag
     * @param z2 whether overflowing encodings are collected into the result
     * @return the encoding
     */
    public Encoding encode(List<String> list, boolean z, boolean z2) {
        String[] texts = (String[]) list.toArray(Utils.EMPTY_ARRAY);
        return encode(texts, z, z2);
    }

    /**
     * Encodes a list of texts using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param list the texts
     * @return the encoding
     */
    public Encoding encode(List<String> list) {
        return encode(list, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes an array of texts into a single encoding.
     *
     * <p>Every element is validated before any processing, so a {@code null} element is reported
     * with the same message whether or not lower-casing is enabled. Lower-casing works on a copy;
     * the caller's array is never modified.
     *
     * @param strArr the texts
     * @param z the "add special tokens" flag passed to the native encoder
     * @param z2 whether overflowing encodings are collected into the result
     * @return the encoding
     * @throws NullPointerException if the array or any element is {@code null}
     */
    public Encoding encode(String[] strArr, boolean z, boolean z2) {
        for (String text : strArr) {
            if (text == null) {
                throw new NullPointerException("input text cannot be null");
            }
        }
        String[] inputs = strArr;
        if (doLowerCase != null) {
            // Lower-case into a fresh array so the caller's data stays untouched.
            inputs = new String[strArr.length];
            for (int i = 0; i < strArr.length; i++) {
                inputs[i] = strArr[i].toLowerCase(doLowerCase);
            }
        }
        long nativeHandle = ((Long) getHandle()).longValue();
        return toEncoding(TokenizersLibrary.LIB.encodeList(nativeHandle, inputs, z), z2);
    }

    /**
     * Encodes an array of texts using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param strArr the texts
     * @return the encoding
     */
    public Encoding encode(String[] strArr) {
        return encode(strArr, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes each text of the list separately by delegating to the array overload.
     *
     * @param list the texts
     * @param z the "add special tokens" flag
     * @param z2 whether overflowing encodings are collected into each result
     * @return one encoding per element returned by the native batch call
     */
    public Encoding[] batchEncode(List<String> list, boolean z, boolean z2) {
        String[] texts = (String[]) list.toArray(Utils.EMPTY_ARRAY);
        return batchEncode(texts, z, z2);
    }

    /**
     * Batch-encodes a list of texts using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param list the texts
     * @return the encodings
     */
    public Encoding[] batchEncode(List<String> list) {
        return batchEncode(list, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes each text of the array separately.
     *
     * <p>Every element is validated before any processing, so a {@code null} element is reported
     * with the same message whether or not lower-casing is enabled. Lower-casing works on a copy;
     * the caller's array is never modified.
     *
     * @param strArr the texts
     * @param z the "add special tokens" flag passed to the native encoder
     * @param z2 whether overflowing encodings are collected into each result
     * @return one encoding per handle returned by the native batch call
     * @throws NullPointerException if the array or any element is {@code null}
     */
    public Encoding[] batchEncode(String[] strArr, boolean z, boolean z2) {
        for (String text : strArr) {
            if (text == null) {
                throw new NullPointerException("input text cannot be null");
            }
        }
        String[] inputs = strArr;
        if (doLowerCase != null) {
            // Lower-case into a fresh array so the caller's data stays untouched.
            inputs = new String[strArr.length];
            for (int i = 0; i < strArr.length; i++) {
                inputs[i] = strArr[i].toLowerCase(doLowerCase);
            }
        }
        long nativeHandle = ((Long) getHandle()).longValue();
        long[] encodingHandles = TokenizersLibrary.LIB.batchEncode(nativeHandle, inputs, z);
        Encoding[] encodings = new Encoding[encodingHandles.length];
        for (int i = 0; i < encodingHandles.length; i++) {
            encodings[i] = toEncoding(encodingHandles[i], z2);
        }
        return encodings;
    }

    /**
     * Batch-encodes an array of texts using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param strArr the texts
     * @return the encodings
     */
    public Encoding[] batchEncode(String[] strArr) {
        return batchEncode(strArr, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Encodes each text pair separately.
     *
     * <p>Keys and values are validated before any processing, so a {@code null} entry is reported
     * with a descriptive message whether or not lower-casing is enabled.
     *
     * @param pairList the text pairs; keys are the first texts and values the second texts
     * @param z the "add special tokens" flag passed to the native encoder
     * @param z2 whether overflowing encodings are collected into each result
     * @return one encoding per handle returned by the native batch call
     * @throws NullPointerException if the list, any key, or any value is {@code null}
     */
    public Encoding[] batchEncode(PairList<String, String> pairList, boolean z, boolean z2) {
        String[] first = (String[]) pairList.keyArray(Utils.EMPTY_ARRAY);
        String[] second = (String[]) pairList.valueArray(Utils.EMPTY_ARRAY);
        for (String text : first) {
            if (text == null) {
                throw new NullPointerException("text pair key cannot be null");
            }
        }
        for (String text : second) {
            if (text == null) {
                throw new NullPointerException("text pair value cannot be null");
            }
        }
        if (doLowerCase != null) {
            for (int i = 0; i < first.length; i++) {
                first[i] = first[i].toLowerCase(doLowerCase);
            }
            for (int i = 0; i < second.length; i++) {
                second[i] = second[i].toLowerCase(doLowerCase);
            }
        }
        long nativeHandle = ((Long) getHandle()).longValue();
        long[] encodingHandles =
                TokenizersLibrary.LIB.batchEncodePair(nativeHandle, first, second, z);
        Encoding[] encodings = new Encoding[encodingHandles.length];
        for (int i = 0; i < encodingHandles.length; i++) {
            encodings[i] = toEncoding(encodingHandles[i], z2);
        }
        return encodings;
    }

    /**
     * Batch-encodes text pairs using the tokenizer's configured addSpecialTokens and
     * withOverflowingTokens settings.
     *
     * @param pairList the text pairs
     * @return the encodings
     */
    public Encoding[] batchEncode(PairList<String, String> pairList) {
        return batchEncode(pairList, addSpecialTokens, withOverflowingTokens);
    }

    /**
     * Decodes token ids into text through the native tokenizer.
     *
     * @param jArr the token ids
     * @param z flag forwarded to the native decoder (presumably "skip special tokens" — confirm
     *     against TokenizersLibrary)
     * @return the decoded text
     */
    public String decode(long[] jArr, boolean z) {
        long nativeHandle = (Long) getHandle();
        return TokenizersLibrary.LIB.decode(nativeHandle, jArr, z);
    }

    /**
     * Decodes token ids into text, passing the negation of the configured addSpecialTokens
     * setting as the decoder flag.
     *
     * @param jArr the token ids
     * @return the decoded text
     */
    public String decode(long[] jArr) {
        boolean decoderFlag = !addSpecialTokens;
        return decode(jArr, decoderFlag);
    }

    /**
     * Decodes several token id sequences through the native tokenizer.
     *
     * @param jArr the token id sequences
     * @param z flag forwarded to the native decoder (presumably "skip special tokens" — confirm
     *     against TokenizersLibrary)
     * @return the decoded texts
     */
    public String[] batchDecode(long[][] jArr, boolean z) {
        long nativeHandle = (Long) getHandle();
        return TokenizersLibrary.LIB.batchDecode(nativeHandle, jArr, z);
    }

    /**
     * Decodes several token id sequences, passing the negation of the configured
     * addSpecialTokens setting as the decoder flag.
     *
     * @param jArr the token id sequences
     * @return the decoded texts
     */
    public String[] batchDecode(long[][] jArr) {
        boolean decoderFlag = !addSpecialTokens;
        return batchDecode(jArr, decoderFlag);
    }

    /**
     * Returns the name of the active truncation strategy.
     *
     * @return the {@code TruncationStrategy} constant name
     */
    public String getTruncation() {
        return truncation.name();
    }

    /**
     * Returns the name of the active padding strategy.
     *
     * @return the {@code PaddingStrategy} constant name
     */
    public String getPadding() {
        return padding.name();
    }

    /**
     * Returns the effective maximum length after construction-time adjustments.
     *
     * @return the maximum length
     */
    public int getMaxLength() {
        return maxLength;
    }

    /**
     * Returns the configured stride.
     *
     * @return the stride
     */
    public int getStride() {
        return stride;
    }

    /**
     * Returns the configured pad-to-multiple-of value.
     *
     * @return the multiple, where {@code 0} means no multiple adjustment is applied
     */
    public int getPadToMultipleOf() {
        return padToMultipleOf;
    }

    /**
     * Creates a new tokenizer builder.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        Builder fresh = new Builder();
        return fresh;
    }

    /**
     * Creates a new tokenizer builder pre-configured from the given options.
     *
     * @param map the options copied into the builder
     * @return the configured builder
     */
    public static Builder builder(Map<String, ?> map) {
        Builder configured = new Builder();
        configured.configure(map);
        return configured;
    }

    /**
     * Reconciles maxLength with modelMaxLength and padToMultipleOf, then pushes the truncation
     * and padding configuration to the native tokenizer.
     *
     * <p>maxLength is only adjusted when truncation is enabled or padding is {@code MAX_LENGTH}:
     * {@code -1} is replaced by modelMaxLength and larger values are capped at it. When both
     * truncation and {@code MAX_LENGTH} padding are active and padToMultipleOf is non-zero,
     * maxLength is rounded up to the next multiple, or down by one multiple if rounding up would
     * exceed modelMaxLength.
     */
    private void updateTruncationAndPadding() {
        boolean truncating = truncation != TruncationStrategy.DO_NOT_TRUNCATE;
        boolean padToMax = padding == PaddingStrategy.MAX_LENGTH;
        if (padToMax || truncating) {
            if (maxLength == -1) {
                logger.warn("maxLength is not explicitly specified, use modelMaxLength: {}", modelMaxLength);
                maxLength = modelMaxLength;
            } else if (maxLength > modelMaxLength) {
                logger.warn("maxLength is greater then modelMaxLength, change to: {}", modelMaxLength);
                maxLength = modelMaxLength;
            }
            if (padToMax && truncating && padToMultipleOf != 0 && maxLength % padToMultipleOf != 0) {
                int aligned = (maxLength + padToMultipleOf) - (maxLength % padToMultipleOf);
                if (aligned > modelMaxLength) {
                    aligned -= padToMultipleOf;
                }
                logger.warn("maxLength ({}) is not a multiple of padToMultipleOf ({}), change to: {}", maxLength, padToMultipleOf, aligned);
                maxLength = aligned;
            }
        }
        long nativeHandle = (Long) getHandle();
        if (!truncating) {
            TokenizersLibrary.LIB.disableTruncation(nativeHandle);
        } else {
            TokenizersLibrary.LIB.setTruncation(nativeHandle, maxLength, truncation.name(), stride);
        }
        if (padding != PaddingStrategy.DO_NOT_PAD) {
            TokenizersLibrary.LIB.setPadding(nativeHandle, maxLength, padding.name(), padToMultipleOf);
        } else {
            TokenizersLibrary.LIB.disablePadding(nativeHandle);
        }
    }

    /**
     * Copies a native encoding into a Java {@code Encoding} and deletes the native object.
     *
     * @param j the native encoding handle; it is deleted before this method returns normally
     * @param z whether the overflowing encodings are converted (recursively) as well; when
     *     {@code false} an empty array is used instead
     * @return the Java encoding
     */
    private Encoding toEncoding(long j, boolean z) {
        long[] ids = TokenizersLibrary.LIB.getTokenIds(j);
        long[] types = TokenizersLibrary.LIB.getTypeIds(j);
        String[] tokenTexts = TokenizersLibrary.LIB.getTokens(j);
        long[] words = TokenizersLibrary.LIB.getWordIds(j);
        long[] sequences = TokenizersLibrary.LIB.getSequenceIds(j);
        long[] attention = TokenizersLibrary.LIB.getAttentionMask(j);
        long[] specialMask = TokenizersLibrary.LIB.getSpecialTokenMask(j);
        CharSpan[] spans = TokenizersLibrary.LIB.getTokenCharSpans(j);
        boolean hasOverflow = TokenizersLibrary.LIB.getOverflowCount(j) > 0;
        Encoding[] overflow;
        if (!z) {
            overflow = new Encoding[0];
        } else {
            long[] overflowHandles = TokenizersLibrary.LIB.getOverflowing(j);
            overflow = new Encoding[overflowHandles.length];
            int index = 0;
            for (long overflowHandle : overflowHandles) {
                overflow[index++] = toEncoding(overflowHandle, true);
            }
        }
        // All data has been copied out; the native encoding is no longer needed.
        TokenizersLibrary.LIB.deleteEncoding(j);
        return new Encoding(ids, types, tokenTexts, words, sequences, attention, specialMask, spans, hasOverflow, overflow);
    }

    /**
     * Safety net that releases the native tokenizer if the instance is garbage collected without
     * having been closed.
     *
     * @throws Throwable if closing or the superclass finalizer fails
     */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            // Always chain to the superclass finalizer, even when close() throws.
            super.finalize();
        }
    }
}
