package ai.djl.basicdataset.tabular;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.training.dataset.Dataset;
import ai.djl.util.Progress;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

/* loaded from: input_file:ai/djl/basicdataset/tabular/MovieLens100k.class */
/**
 * The MovieLens 100k dataset exposed as a tabular {@link CsvDataset}.
 *
 * <p>Rows come from a tab-separated rating split ({@code ua.base} for {@link Dataset.Usage#TRAIN},
 * {@code ua.test} for {@link Dataset.Usage#TEST}) whose columns are {@link HeaderEnum}. User and
 * movie attributes are looked up from {@code u.user} and {@code u.item} by id when a cell is
 * requested, so user/movie features can be combined with the {@code rating} label.
 */
public final class MovieLens100k extends CsvDataset {

    private static final String ARTIFACT_ID = "movielens-100k";
    private static final String VERSION = "1.0";

    /** Column names of the pipe-separated {@code u.user} file, in column order. */
    private static final String[] USER_FEATURES = {
        "user_id", "user_age", "user_gender", "user_occupation", "user_zipcode"
    };

    /** Column names of the pipe-separated {@code u.item} file, in column order. */
    private static final String[] MOVIE_FEATURES = {
        "movie_id", "movie_title", "movie_release_date", "movie_video_release_date", "imdb_url",
        "unknown", "action", "adventure", "animation", "childrens", "comedy", "crime",
        "documentary", "drama", "fantasy", "film-noir", "horror", "musical", "mystery", "romance",
        "sci-fi", "thriller", "war", "western"
    };

    private final Dataset.Usage usage;
    private final MRL mrl;
    private boolean prepared;

    /** user_id -> (column name -> raw value); filled in by {@link #prepare(Progress)}. */
    private Map<String, Map<String, String>> userFeaturesMap;

    /** movie_id -> (column name -> raw value); filled in by {@link #prepare(Progress)}. */
    private Map<String, Map<String, String>> movieFeaturesMap;

    /** A builder for {@link MovieLens100k}. */
    public static final class Builder extends CsvDataset.CsvBuilder<Builder> {

        List<String> featureArray =
                new ArrayList<>(
                        Arrays.asList(
                                "user_age",
                                "user_gender",
                                "user_occupation",
                                "user_zipcode",
                                "movie_title",
                                "movie_genres"));
        List<String> movieGenres =
                new ArrayList<>(
                        Arrays.asList(
                                "unknown", "action", "adventure", "animation", "childrens",
                                "comedy", "crime", "documentary", "drama", "fantasy",
                                "film-noir", "horror", "musical", "mystery", "romance",
                                "sci-fi", "thriller", "war", "western"));
        Repository repository = BasicDatasets.REPOSITORY;
        String groupId = BasicDatasets.GROUP_ID;
        String artifactId = ARTIFACT_ID;
        Dataset.Usage usage = Dataset.Usage.TRAIN;

        Builder() {
            // Rating files are tab-separated with no header row and no quoting.
            this.csvFormat =
                    CSVFormat.TDF
                            .builder()
                            .setHeader(HeaderEnum.class)
                            .setQuote((Character) null)
                            .build();
        }

        /**
         * Returns this builder (the decompiler-renamed {@code self()} override).
         *
         * @return this builder
         */
        @Override
        public Builder mo20self() {
            return this;
        }

        /**
         * Sets which rating split to load.
         *
         * @param usage the dataset usage; only TRAIN and TEST are available
         * @return this builder
         */
        public Builder optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return mo20self();
        }

        /**
         * Sets the repository the artifact is fetched from.
         *
         * @param repository the repository
         * @return this builder
         */
        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return mo20self();
        }

        /**
         * Sets the artifact group id.
         *
         * @param groupId the group id
         * @return this builder
         */
        public Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return mo20self();
        }

        /**
         * Sets the artifact id, optionally as {@code groupId:artifactId}.
         *
         * @param artifactId the artifact id, or a {@code group:artifact} pair
         * @return this builder
         */
        public Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] parts = artifactId.split(":");
                this.groupId = parts[0];
                this.artifactId = parts[1];
            } else {
                this.artifactId = artifactId;
            }
            return mo20self();
        }

        /**
         * Returns the feature names accepted by {@link #addFeature(String)}.
         *
         * @return the (mutable, internal) list of available feature names
         */
        public List<String> getAvailableFeatures() {
            return featureArray;
        }

        /**
         * Adds a feature with the encoding appropriate for it.
         *
         * <p>{@code movie_genres} expands into one numeric feature per genre column.
         *
         * @param feature one of {@link #getAvailableFeatures()}
         * @return this builder
         * @throws IllegalArgumentException if {@code feature} is not an available feature
         */
        public Builder addFeature(String feature) {
            if (!getAvailableFeatures().contains(feature)) {
                throw new IllegalArgumentException(
                        String.format(
                                "Provided feature %s is not valid. Valid features are: %s",
                                feature, featureArray));
            }
            switch (feature) {
                case "user_age":
                    addNumericFeature(feature);
                    break;
                case "user_gender":
                case "user_occupation":
                    addCategoricalFeature(feature, true);
                    break;
                case "user_zipcode":
                case "movie_title":
                    addCategoricalFeature(feature, false);
                    break;
                case "movie_genres":
                    movieGenres.forEach(genre -> addNumericFeature(genre));
                    break;
                default:
                    // A name present in featureArray but without a known encoding is ignored,
                    // matching the original behavior.
                    break;
            }
            return mo20self();
        }

        /**
         * Builds the dataset, defaulting to all available features and a categorical
         * {@code rating} label when none were configured.
         *
         * @return the configured dataset
         */
        @Override
        public MovieLens100k build() {
            if (features.isEmpty()) {
                featureArray.forEach(this::addFeature);
            }
            if (labels.isEmpty()) {
                addCategoricalLabel("rating", true);
            }
            return new MovieLens100k(this);
        }

        MRL getMrl() {
            return repository.dataset(Application.Tabular.ANY, groupId, artifactId, VERSION);
        }
    }

    /** Columns of the tab-separated rating files ({@code ua.base} / {@code ua.test}). */
    enum HeaderEnum {
        user_id,
        movie_id,
        rating,
        timestamp
    }

    MovieLens100k(Builder builder) {
        super(builder);
        this.usage = builder.usage;
        this.mrl = builder.getMrl();
    }

    /**
     * Returns one cell: the rating itself, a user attribute (names starting with {@code user}),
     * or otherwise a movie attribute, joined via the row's user/movie id.
     */
    @Override
    protected String getCell(long rowIndex, String featureName) {
        CSVRecord row = csvRecords.get(Math.toIntExact(rowIndex));
        if (HeaderEnum.rating.toString().equals(featureName)) {
            return row.get(HeaderEnum.rating);
        }
        if (featureName.startsWith("user")) {
            return userFeaturesMap.get(row.get(HeaderEnum.user_id)).get(featureName);
        }
        return movieFeaturesMap.get(row.get(HeaderEnum.movie_id)).get(featureName);
    }

    /**
     * Downloads the artifact (once), loads the user/movie lookup tables and the rating split
     * selected by the builder's usage.
     *
     * @throws UnsupportedOperationException for usages other than TRAIN and TEST
     */
    @Override
    public void prepare(Progress progress) throws IOException {
        if (prepared) {
            return;
        }
        // Resolve the split first so an unsupported usage fails before any download.
        String ratingsFile = ratingsFileFor(usage);

        Artifact artifact = mrl.getDefaultArtifact();
        mrl.prepare(artifact, progress);
        Path root = mrl.getRepository().getResourceDirectory(artifact).resolve("ml-100k/ml-100k");

        userFeaturesMap = prepareFeaturesMap(root.resolve("u.user"), USER_FEATURES);
        movieFeaturesMap = prepareFeaturesMap(root.resolve("u.item"), MOVIE_FEATURES);

        csvUrl = root.resolve(ratingsFile).toUri().toURL();
        super.prepare(progress);
        prepared = true;
    }

    /** Maps a usage to the name of its rating split file. */
    private static String ratingsFileFor(Dataset.Usage usage) {
        switch (usage) {
            case TRAIN:
                return "ua.base";
            case TEST:
                return "ua.test";
            case VALIDATION:
            default:
                throw new UnsupportedOperationException("Validation data not available");
        }
    }

    /**
     * Reads a pipe-separated attribute file into {@code id -> (column name -> value)}, keyed by
     * the first column.
     *
     * @param featureFile the {@code u.user} or {@code u.item} file
     * @param featureNames column names in file order; only the first {@code featureNames.length}
     *     columns are kept
     * @return the lookup table
     * @throws IOException if the file cannot be read
     */
    private Map<String, Map<String, String>> prepareFeaturesMap(
            Path featureFile, String[] featureNames) throws IOException {
        // No quote character: movie titles may contain literal quotes.
        CSVFormat format = CSVFormat.Builder.create(CSVFormat.newFormat('|')).build();
        List<CSVRecord> records;
        // The ml-100k files are Latin-1 encoded (u.item has accented titles); decoding them as
        // UTF-8 would replace those characters. Both the reader and the parser are closed here.
        try (InputStreamReader reader =
                        new InputStreamReader(
                                new BufferedInputStream(featureFile.toUri().toURL().openStream()),
                                StandardCharsets.ISO_8859_1);
                CSVParser parser = new CSVParser(reader, format)) {
            records = parser.getRecords();
        }

        Map<String, Map<String, String>> featuresMap = new ConcurrentHashMap<>();
        for (CSVRecord record : records) {
            Map<String, String> values = new ConcurrentHashMap<>();
            for (int i = 0; i < featureNames.length; i++) {
                values.put(featureNames[i], record.get(i));
            }
            featuresMap.put(record.get(0), values);
        }
        return featuresMap;
    }

    /**
     * Creates a new builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
