package ai.djl.basicdataset.tabular;

import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.basicdataset.tabular.utils.Featurizers;
import ai.djl.basicdataset.tabular.utils.PreparedFeaturizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Base class for random-access datasets whose rows are assembled cell by cell.
 *
 * <p>Each configured {@link Feature} names a column and carries a featurizer; a row is produced by
 * reading the cell for every feature (via {@link #getCell(long, String)}) and letting that
 * feature's featurizer append its values to a buffer, which becomes a 1-D array.
 */
public abstract class TabularDataset extends RandomAccessDataset {

    protected List<Feature> features;
    protected List<Feature> labels;

    /**
     * Builder base that collects the feature and label columns of a {@link TabularDataset}.
     *
     * @param <T> the concrete builder type, returned from every fluent method
     */
    public abstract static class BaseBuilder<T extends BaseBuilder<T>>
            extends RandomAccessDataset.BaseBuilder<T> {

        protected List<Feature> features = new ArrayList<>();
        protected List<Feature> labels = new ArrayList<>();
        protected boolean allowNoLabels;

        /**
         * Adds pre-built feature columns.
         *
         * @param features the features to append, in order
         * @return this builder
         */
        public T addFeature(Feature... features) {
            Collections.addAll(this.features, features);
            return self();
        }

        /**
         * Adds a numeric feature column built with {@code new Feature(name, true)}.
         *
         * @param name the column name
         * @return this builder
         */
        public T addNumericFeature(String name) {
            features.add(new Feature(name, true));
            return self();
        }

        /**
         * Adds a numeric feature column using {@link Featurizers#getNumericFeaturizer(boolean)}.
         *
         * @param name the column name
         * @param normalize flag forwarded to {@code Featurizers.getNumericFeaturizer}
         *     (presumably toggles normalization — confirm against Featurizers)
         * @return this builder
         */
        public T addNumericFeature(String name, boolean normalize) {
            features.add(new Feature(name, Featurizers.getNumericFeaturizer(normalize)));
            return self();
        }

        /**
         * Adds a categorical feature column built with {@code new Feature(name, false)}.
         *
         * @param name the column name
         * @return this builder
         */
        public T addCategoricalFeature(String name) {
            features.add(new Feature(name, false));
            return self();
        }

        /**
         * Adds a categorical feature column using {@link Featurizers#getStringFeaturizer(boolean)}.
         *
         * @param name the column name
         * @param onehotEncode flag forwarded to {@code Featurizers.getStringFeaturizer}
         *     (presumably selects one-hot encoding — confirm against Featurizers)
         * @return this builder
         */
        public T addCategoricalFeature(String name, boolean onehotEncode) {
            features.add(new Feature(name, Featurizers.getStringFeaturizer(onehotEncode)));
            return self();
        }

        /**
         * Adds a categorical feature column with an explicit value map.
         *
         * @param name the column name
         * @param map value map passed to {@code Feature} (presumably category to index)
         * @param onehotEncode flag passed to {@code Feature} (presumably one-hot encoding)
         * @return this builder
         */
        public T addCategoricalFeature(String name, Map<String, Integer> map, boolean onehotEncode) {
            features.add(new Feature(name, map, onehotEncode));
            return self();
        }

        /**
         * Adds pre-built label columns.
         *
         * @param labels the labels to append, in order
         * @return this builder
         */
        public T addLabel(Feature... labels) {
            Collections.addAll(this.labels, labels);
            return self();
        }

        /**
         * Adds a numeric label column built with {@code new Feature(name, true)}.
         *
         * @param name the column name
         * @return this builder
         */
        public T addNumericLabel(String name) {
            labels.add(new Feature(name, true));
            return self();
        }

        /**
         * Adds a numeric label column using {@link Featurizers#getNumericFeaturizer(boolean)}.
         *
         * @param name the column name
         * @param normalize flag forwarded to {@code Featurizers.getNumericFeaturizer}
         *     (presumably toggles normalization — confirm against Featurizers)
         * @return this builder
         */
        public T addNumericLabel(String name, boolean normalize) {
            labels.add(new Feature(name, Featurizers.getNumericFeaturizer(normalize)));
            return self();
        }

        /**
         * Adds a categorical label column built with {@code new Feature(name, false)}.
         *
         * @param name the column name
         * @return this builder
         */
        public T addCategoricalLabel(String name) {
            labels.add(new Feature(name, false));
            return self();
        }

        /**
         * Adds a categorical label column using {@link Featurizers#getStringFeaturizer(boolean)}.
         *
         * @param name the column name
         * @param onehotEncode flag forwarded to {@code Featurizers.getStringFeaturizer}
         *     (presumably selects one-hot encoding — confirm against Featurizers)
         * @return this builder
         */
        public T addCategoricalLabel(String name, boolean onehotEncode) {
            labels.add(new Feature(name, Featurizers.getStringFeaturizer(onehotEncode)));
            return self();
        }

        /**
         * Adds a categorical label column with an explicit value map.
         *
         * @param name the column name
         * @param map value map passed to {@code Feature} (presumably category to index)
         * @param onehotEncode flag passed to {@code Feature} (presumably one-hot encoding)
         * @return this builder
         */
        public T addCategoricalLabel(String name, Map<String, Integer> map, boolean onehotEncode) {
            labels.add(new Feature(name, map, onehotEncode));
            return self();
        }

        /**
         * Allows the dataset to be built without any label columns.
         *
         * @return this builder
         */
        public T noLabels() {
            allowNoLabels = true;
            return self();
        }
    }

    /**
     * Creates the dataset from a configured builder.
     *
     * @param builder the builder holding the feature and label columns
     * @throws IllegalArgumentException if no features were configured, or if no labels were
     *     configured and {@link BaseBuilder#noLabels()} was not called
     */
    public TabularDataset(BaseBuilder<?> builder) {
        super(builder);
        // Copy so that reusing or mutating the builder later cannot change this dataset.
        features = new ArrayList<>(builder.features);
        labels = new ArrayList<>(builder.labels);
        if (features.isEmpty()) {
            throw new IllegalArgumentException("Missing features.");
        }
        if (labels.isEmpty() && !builder.allowNoLabels) {
            throw new IllegalArgumentException("Missing labels.");
        }
    }

    /**
     * Returns the number of configured feature columns.
     *
     * @return the feature column count
     */
    public int getFeatureSize() {
        return features.size();
    }

    /**
     * Returns the number of configured label columns.
     *
     * @return the label column count
     */
    public int getLabelSize() {
        return labels.size();
    }

    /**
     * Builds the record for one row.
     *
     * @param manager the manager that allocates the arrays
     * @param index the row index
     * @return a record whose data is the featurized feature columns and whose labels are the
     *     featurized label columns, or an empty {@code NDList} when no labels are configured
     */
    @Override
    public Record get(NDManager manager, long index) {
        NDList data = getRowFeatures(manager, index, features);
        NDList label = labels.isEmpty() ? new NDList() : getRowFeatures(manager, index, labels);
        return new Record(data, label);
    }

    /**
     * Featurizes the given columns of one row into a single 1-D array.
     *
     * @param manager the manager that allocates the array
     * @param index the row index
     * @param selected the columns to featurize, in output order
     * @return a list holding one array whose length is the total number of featurized values
     */
    public NDList getRowFeatures(NDManager manager, long index, List<Feature> selected) {
        DynamicBuffer buffer = new DynamicBuffer();
        for (Feature feature : selected) {
            String cell = getCell(index, feature.getName());
            feature.getFeaturizer().featurize(buffer, cell);
        }
        return new NDList(manager.create(buffer.getBuffer(), new Shape(buffer.getLength())));
    }

    /**
     * Feeds every available cell of each column whose featurizer is a {@link PreparedFeaturizer}
     * to that featurizer's {@code prepare} method. Feature columns are processed before labels.
     *
     * @throws ArithmeticException if {@code availableSize()} does not fit in an {@code int}
     */
    public void prepareFeaturizers() {
        int rowCount = Math.toIntExact(availableSize());
        List<Feature> allColumns = new ArrayList<>(features.size() + labels.size());
        allColumns.addAll(features);
        allColumns.addAll(labels);
        for (Feature feature : allColumns) {
            if (feature.getFeaturizer() instanceof PreparedFeaturizer) {
                PreparedFeaturizer featurizer = (PreparedFeaturizer) feature.getFeaturizer();
                featurizer.prepare(readColumn(feature.getName(), rowCount));
            }
        }
    }

    /** Reads the cells of one column for rows {@code 0 .. rowCount-1}, in row order. */
    private List<String> readColumn(String name, int rowCount) {
        List<String> values = new ArrayList<>(rowCount);
        for (int row = 0; row < rowCount; row++) {
            values.add(getCell(row, name));
        }
        return values;
    }

    /**
     * Returns the raw value of one cell.
     *
     * @param rowIndex the row index
     * @param featureName the column name, as given to the {@link Feature}
     * @return the cell value as a string
     */
    protected abstract String getCell(long rowIndex, String featureName);
}
