package org.deeplearning4j.models.sequencevectors.transformers.impl.iterables;

import java.util.NoSuchElementException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Transformer iterator that converts {@link LabelledDocument}s into {@link Sequence}s of
 * {@link VocabWord}s using a pool of daemon tokenizer threads.
 *
 * <p>Documents are pulled from the wrapped {@link LabelAwareIterator} (through an
 * {@link AsyncLabelAwareIterator}), queued in {@link #stringBuffer}, tokenized by the worker
 * threads and placed in {@link #buffer}, from which {@link #next()} hands them out. When
 * multithreading is allowed, {@code max(availableProcessors / 2, 2)} workers are started;
 * otherwise one.
 *
 * <p>{@code hasNext()}/{@code next()} are not synchronized; presumably a single consumer
 * thread is expected (NOTE(review): confirm with callers).
 */
public class ParallelTransformerIterator extends BasicTransformerIterator {
    /** Capacity of the queue holding finished sequences. */
    private static final int SEQUENCE_BUFFER_CAPACITY = 1024;
    /** Capacity of the queue holding documents waiting to be tokenized. */
    private static final int DOCUMENT_BUFFER_CAPACITY = 512;
    /** Queue size handed to the {@link AsyncLabelAwareIterator} wrapper. */
    private static final int ASYNC_PREFETCH_SIZE = 512;
    /** Maximum number of documents queued before the worker threads are started. */
    private static final int PREFILL_SIZE = 256;
    /** How long {@link #next()} waits for a result before re-checking whether work is in flight. */
    private static final long POLL_INTERVAL_MS = 10L;

    /** Tokenized sequences ready to be returned by {@link #next()}. */
    protected BlockingQueue<Sequence<VocabWord>> buffer;
    /** Documents waiting to be picked up by a worker thread. */
    protected BlockingQueue<LabelledDocument> stringBuffer;
    protected TokenizerThread[] threads;
    /** Last known answer of the underlying iterator's {@code hasNextDocument()}. */
    protected boolean underlyingHas;
    /**
     * Number of documents that have been queued for the workers and whose result has not yet
     * been placed into {@link #buffer} (or which have not yet been discarded). It is incremented
     * before a document is queued and decremented only after its result is published, so there
     * is no window in which a document is invisible to {@link #hasNext()}.
     */
    protected AtomicInteger processing;
    private static final Logger log = LoggerFactory.getLogger(ParallelTransformerIterator.class);
    protected static final AtomicInteger count = new AtomicInteger(0);

    /** Worker thread that tokenizes queued documents into sequences. */
    private static class TokenizerThread extends Thread {
        protected BlockingQueue<Sequence<VocabWord>> sequencesBuffer;
        protected BlockingQueue<LabelledDocument> stringsBuffer;
        protected SentenceTransformer sentenceTransformer;
        protected AtomicBoolean shouldWork = new AtomicBoolean(true);
        protected AtomicInteger processing;

        /**
         * @param i                   index used in the thread name
         * @param sentenceTransformer transformer that turns document text into a sequence
         * @param documents           queue the worker takes documents from
         * @param sequences           queue the worker puts finished sequences into
         * @param inFlight            shared in-flight counter; decremented once per taken document
         */
        public TokenizerThread(int i, SentenceTransformer sentenceTransformer, BlockingQueue<LabelledDocument> documents, BlockingQueue<Sequence<VocabWord>> sequences, AtomicInteger inFlight) {
            this.stringsBuffer = documents;
            this.sequencesBuffer = sequences;
            this.sentenceTransformer = sentenceTransformer;
            this.processing = inFlight;
            setDaemon(true);
            setName("Tokenization thread " + i);
        }

        @Override
        public void run() {
            while (this.shouldWork.get()) {
                LabelledDocument document;
                try {
                    document = this.stringsBuffer.take();
                } catch (InterruptedException e) {
                    shutdown();
                    Thread.currentThread().interrupt();
                    return;
                }
                try {
                    Sequence<VocabWord> sequence = toSequence(document);
                    if (sequence != null) {
                        this.sequencesBuffer.put(sequence);
                    }
                } catch (InterruptedException e) {
                    shutdown();
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    throw new RuntimeException(e);
                } finally {
                    // Always release the document's slot, after its result (if any) is published,
                    // so a failure or skip can never leave hasNext() reporting phantom work.
                    this.processing.decrementAndGet();
                }
            }
        }

        /**
         * Tokenizes a document and attaches its non-empty labels.
         *
         * @return the sequence, or {@code null} when the document has no content or the
         *     transformer produced nothing
         */
        private Sequence<VocabWord> toSequence(LabelledDocument document) {
            if (document == null || document.getContent() == null) {
                return null;
            }
            Sequence<VocabWord> sequence = this.sentenceTransformer.transformToSequence(document.getContent());
            // Null-check before touching the sequence: the transformer may return null.
            if (sequence != null && document.getLabels() != null) {
                for (String label : document.getLabels()) {
                    if (label != null && !label.isEmpty()) {
                        sequence.addSequenceLabel(new VocabWord(1.0d, label));
                    }
                }
            }
            return sequence;
        }

        /** Asks the worker to stop after its current iteration. */
        public void shutdown() {
            this.shouldWork.set(false);
        }
    }

    /**
     * Creates a multithreaded iterator.
     *
     * @param labelAwareIterator source of documents; must not be null
     * @param sentenceTransformer transformer that tokenizes document text; must not be null
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator labelAwareIterator, @NonNull SentenceTransformer sentenceTransformer) {
        // Null checks are performed by the delegated constructor.
        this(labelAwareIterator, sentenceTransformer, true);
    }

    /**
     * Creates the iterator, queues up to {@value #PREFILL_SIZE} documents, and starts the workers.
     *
     * @param labelAwareIterator source of documents; must not be null
     * @param sentenceTransformer transformer that tokenizes document text; must not be null
     * @param z whether to use several worker threads ({@code true}) or a single one
     * @throws NullPointerException if {@code labelAwareIterator} or {@code sentenceTransformer}
     *     is null
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator labelAwareIterator, @NonNull SentenceTransformer sentenceTransformer, boolean z) {
        super(new AsyncLabelAwareIterator(labelAwareIterator, ASYNC_PREFETCH_SIZE), sentenceTransformer);
        this.buffer = new LinkedBlockingQueue<>(SEQUENCE_BUFFER_CAPACITY);
        this.underlyingHas = true;
        this.processing = new AtomicInteger(0);
        if (labelAwareIterator == null) {
            throw new NullPointerException("iterator");
        }
        if (sentenceTransformer == null) {
            throw new NullPointerException("transformer");
        }
        this.allowMultithreading = z;
        this.stringBuffer = new LinkedBlockingQueue<>(DOCUMENT_BUFFER_CAPACITY);
        this.threads = new TokenizerThread[z ? Math.max(Runtime.getRuntime().availableProcessors() / 2, 2) : 1];
        prefill();
        startWorkers(sentenceTransformer);
    }

    /**
     * Queues up to {@link #PREFILL_SIZE} documents before the workers start. Best effort: a
     * document whose retrieval fails is logged and skipped, and the number of attempts is
     * bounded so a persistently failing iterator cannot hang the constructor.
     */
    private void prefill() {
        for (int attempt = 0; attempt < PREFILL_SIZE; attempt++) {
            try {
                this.underlyingHas = this.iterator.hasNextDocument();
                if (!this.underlyingHas) {
                    return;
                }
                submit(this.iterator.nextDocument());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            } catch (Exception e) {
                log.warn("Failed to prefetch document; skipping it", e);
            }
        }
    }

    /** Creates and starts one daemon tokenizer thread per slot in {@link #threads}. */
    private void startWorkers(SentenceTransformer sentenceTransformer) {
        for (int i = 0; i < this.threads.length; i++) {
            this.threads[i] = new TokenizerThread(i, sentenceTransformer, this.stringBuffer, this.buffer, this.processing);
            this.threads[i].setDaemon(true);
            this.threads[i].setName("ParallelTransformer thread " + i);
            this.threads[i].start();
        }
    }

    /**
     * Hands a document to the workers, counting it as in flight before it becomes visible in
     * {@link #stringBuffer}. Null documents are ignored ({@link LinkedBlockingQueue} rejects
     * nulls), as are documents with null content, which the workers would discard anyway.
     */
    private void submit(LabelledDocument document) throws InterruptedException {
        if (document == null || document.getContent() == null) {
            return;
        }
        this.processing.incrementAndGet();
        boolean queued = false;
        try {
            this.stringBuffer.put(document);
            queued = true;
        } finally {
            if (!queued) {
                this.processing.decrementAndGet();
            }
        }
    }

    /**
     * Shuts down the underlying iterator and asks every worker thread to stop (and interrupts
     * it). Despite the name, this does not rewind the iterator.
     */
    @Override
    public void reset() {
        this.iterator.shutdown();
        for (TokenizerThread thread : this.threads) {
            if (thread != null) {
                thread.shutdown();
                try {
                    thread.interrupt();
                } catch (SecurityException ignored) {
                    // Best effort: shouldWork is already cleared, so the worker exits after its
                    // current iteration completes.
                }
            }
        }
    }

    /**
     * Reports whether more sequences may be available. Non-blocking.
     *
     * <p>At the very end of the stream this can return {@code true} for an in-flight document
     * that the transformer then turns into nothing; in that case {@link #next()} throws
     * {@link NoSuchElementException} instead of blocking forever.
     */
    @Override
    public boolean hasNext() {
        if (this.underlyingHas) {
            this.underlyingHas = this.iterator.hasNextDocument();
        }
        // Read the in-flight counter BEFORE the result queue: a worker publishes its result
        // before decrementing, so processing == 0 guarantees any result is already in buffer.
        return this.underlyingHas
                || this.processing.get() > 0
                || !this.buffer.isEmpty()
                || !this.stringBuffer.isEmpty();
    }

    /**
     * Queues the next document from the underlying iterator (if any) and returns the next
     * tokenized sequence, waiting for a worker if necessary.
     *
     * @throws NoSuchElementException if no work is in flight and the underlying iterator is
     *     exhausted
     * @throws RuntimeException wrapping any other failure (including interruption, in which
     *     case the thread's interrupt status is restored)
     */
    @Override
    public Sequence<VocabWord> next() {
        try {
            if (this.underlyingHas) {
                submit(this.iterator.nextDocument());
            }
            while (true) {
                Sequence<VocabWord> sequence = this.buffer.poll(POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);
                if (sequence != null) {
                    return sequence;
                }
                // Counter first, then queues (see hasNext) to avoid missing a just-published result.
                if (this.processing.get() == 0 && this.buffer.isEmpty() && this.stringBuffer.isEmpty()) {
                    // Nothing queued or in flight: everything submitted so far was discarded.
                    if (this.underlyingHas) {
                        this.underlyingHas = this.iterator.hasNextDocument();
                    }
                    if (!this.underlyingHas) {
                        throw new NoSuchElementException("No more sequences available");
                    }
                    submit(this.iterator.nextDocument());
                }
            }
        } catch (NoSuchElementException e) {
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
