package ai.djl.basicdataset.utils;

import ai.djl.basicdataset.TextDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Sampler;
import ai.djl.util.RandomUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.TreeSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Sampler} that groups the samples of a {@link TextDataset} into buckets of similar
 * sentence length and emits batches whose indices are all drawn from a single bucket.
 *
 * <p>Bucket upper bounds are spaced evenly, ending at the longest sentence length and clamped
 * below at the shortest one.
 *
 * <p>NOTE(review): the first and last entries of {@code TextDataset.getSamples()} are used as the
 * shortest and longest sentences, and buckets are assigned in a single forward pass, so this class
 * assumes the samples are ordered by ascending sentence length — confirm against TextDataset.
 */
public class FixedBucketSampler implements Sampler {

    private static final Logger logger = LoggerFactory.getLogger(FixedBucketSampler.class);

    private final int numBuckets;
    private final int batchSize;
    private final boolean shuffle;

    /** Iterates once over all batches of a {@link TextDataset}. */
    private static final class Iterate implements Iterator<List<Long>> {

        private final int batchSize;
        /** Samples grouped by sentence length; each inner list is one bucket. */
        private final List<List<TextDataset.Sample>> buckets;
        /** One {@code {bucketIndex, startOffset}} pair per batch, in emission order. */
        private final List<int[]> bucketBatch;
        /** Position in {@link #bucketBatch} of the next batch to return. */
        private int current;

        Iterate(TextDataset dataset, int numBuckets, int batchSize, boolean shuffle) {
            this.batchSize = batchSize;
            List<TextDataset.Sample> samples = dataset.getSamples();
            if (samples.isEmpty()) {
                // Nothing to bucket: produce an iterator with no batches.
                this.buckets = new ArrayList<>();
            } else {
                this.buckets = partition(samples, bucketKeys(samples, numBuckets));
            }
            this.bucketBatch = batchOffsets(buckets, batchSize);
            if (shuffle) {
                // Shuffling inside a bucket keeps its size, so the offsets above stay valid.
                Collections.shuffle(bucketBatch, RandomUtils.RANDOM);
                for (List<TextDataset.Sample> bucket : buckets) {
                    Collections.shuffle(bucket, RandomUtils.RANDOM);
                }
            }
        }

        /**
         * Computes the inclusive upper sentence-length bound of each bucket, in ascending order.
         *
         * @param samples the non-empty sample list; first and last entries give min/max length
         * @param numBuckets the requested number of buckets
         * @return distinct, ascending bounds; the last one is the longest sentence length
         */
        private static int[] bucketKeys(List<TextDataset.Sample> samples, int numBuckets) {
            int minLength = samples.get(0).getSentenceLength();
            int maxLength = samples.get(samples.size() - 1).getSentenceLength();
            int step = Math.max((1 + maxLength - minLength) / numBuckets, 1);
            // A TreeSet both removes duplicates created by clamping and keeps the bounds sorted;
            // a HashSet iterates in hash order, which is not ascending for larger lengths.
            Set<Integer> keys = new TreeSet<>();
            for (int i = 0; i < numBuckets; i++) {
                keys.add(Math.max(maxLength - (numBuckets - 1 - i) * step, minLength));
            }
            return keys.stream().mapToInt(Integer::intValue).toArray();
        }

        /**
         * Splits the samples into consecutive, non-empty buckets. Assuming the samples are in
         * ascending length order, every sample ends up in the bucket of the smallest bound that
         * is not below its length.
         *
         * @param samples the samples to split
         * @param keys ascending bucket upper bounds
         * @return the buckets, in sample order
         */
        private static List<List<TextDataset.Sample>> partition(
                List<TextDataset.Sample> samples, int[] keys) {
            List<List<TextDataset.Sample>> result = new ArrayList<>(keys.length);
            List<TextDataset.Sample> currentBucket = new ArrayList<>();
            int keyIndex = 0;
            for (TextDataset.Sample sample : samples) {
                int length = sample.getSentenceLength();
                if (length > keys[keyIndex]) {
                    if (!currentBucket.isEmpty()) {
                        result.add(currentBucket);
                        currentBucket = new ArrayList<>();
                    }
                    // Skip every bound this sample exceeds, not just one, so samples of similar
                    // length are not scattered over extra buckets. Never step past the last bound.
                    while (keyIndex < keys.length - 1 && length > keys[keyIndex]) {
                        keyIndex++;
                    }
                }
                currentBucket.add(sample);
            }
            if (!currentBucket.isEmpty()) {
                result.add(currentBucket);
            }
            return result;
        }

        /**
         * Lists a {@code {bucketIndex, startOffset}} pair for every batch of every bucket.
         *
         * @param buckets the bucketed samples
         * @param batchSize the maximum number of samples per batch (positive)
         * @return the batch descriptors, grouped by bucket
         */
        private static List<int[]> batchOffsets(
                List<List<TextDataset.Sample>> buckets, int batchSize) {
            List<int[]> result = new ArrayList<>();
            for (int bucketIndex = 0; bucketIndex < buckets.size(); bucketIndex++) {
                int bucketSize = buckets.get(bucketIndex).size();
                for (int start = 0; start < bucketSize; start += batchSize) {
                    result.add(new int[] {bucketIndex, start});
                }
            }
            return result;
        }

        @Override
        public boolean hasNext() {
            return current < bucketBatch.size();
        }

        /**
         * Returns the dataset indices of the next batch.
         *
         * @return up to {@code batchSize} indices, all from the same bucket
         * @throws NoSuchElementException if all batches have already been returned
         */
        @Override
        public List<Long> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            int[] entry = bucketBatch.get(current);
            List<TextDataset.Sample> bucket = buckets.get(entry[0]);
            int start = entry[1];
            int end = Math.min(bucket.size(), start + batchSize);
            List<Long> batch = new ArrayList<>(end - start);
            for (int i = start; i < end; i++) {
                batch.add(Long.valueOf(bucket.get(i).getIndex()));
            }
            current++;
            return batch;
        }
    }

    /**
     * Creates a sampler.
     *
     * @param batchSize the maximum number of sample indices per batch
     * @param numBuckets the number of sentence-length buckets to aim for
     * @param shuffle whether to shuffle the batch order and the samples within each bucket
     * @throws IllegalArgumentException if {@code batchSize} or {@code numBuckets} is not positive
     */
    public FixedBucketSampler(int batchSize, int numBuckets, boolean shuffle) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        if (numBuckets <= 0) {
            throw new IllegalArgumentException("numBuckets must be positive, got " + numBuckets);
        }
        this.numBuckets = numBuckets;
        this.batchSize = batchSize;
        this.shuffle = shuffle;
        if (batchSize == 1) {
            logger.warn("FixedBucketSampler is not meaningful with batch size 1.");
        }
    }

    /**
     * Creates a sampler that shuffles.
     *
     * @param batchSize the maximum number of sample indices per batch
     * @param numBuckets the number of sentence-length buckets to aim for
     */
    public FixedBucketSampler(int batchSize, int numBuckets) {
        this(batchSize, numBuckets, true);
    }

    /**
     * Creates a shuffling sampler with 10 buckets.
     *
     * @param batchSize the maximum number of sample indices per batch
     */
    public FixedBucketSampler(int batchSize) {
        this(batchSize, 10);
    }

    /**
     * Returns an iterator over batches of sample indices for one pass through the dataset.
     *
     * @param dataset the dataset to sample; must be a {@link TextDataset}
     * @return an iterator whose elements are index lists, each drawn from a single bucket
     * @throws IllegalArgumentException if {@code dataset} is not a {@link TextDataset}
     */
    public Iterator<List<Long>> sample(RandomAccessDataset dataset) {
        if (!(dataset instanceof TextDataset)) {
            throw new IllegalArgumentException(
                    "FixedBucketSampler can only be used with TextDataset");
        }
        return new Iterate((TextDataset) dataset, numBuckets, batchSize, shuffle);
    }

    /**
     * Returns the maximum number of sample indices per batch.
     *
     * @return the batch size
     */
    public int getBatchSize() {
        return batchSize;
    }
}
