/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.wordvectors;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.commons.lang.ArrayUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;

public class WordVectorsImpl<T extends SequenceElement>
implements WordVectors {
    private static final long serialVersionUID = 78249242142L;
    protected int minWordFrequency = 5;
    protected WeightLookupTable<T> lookupTable;
    protected VocabCache<T> vocab;
    protected int layerSize = 100;
    protected transient ModelUtils<T> modelUtils = new BasicModelUtils();
    private boolean initDone = false;
    protected int numIterations = 1;
    protected int numEpochs = 1;
    protected double negative = 0.0;
    protected double sampling = 0.0;
    protected AtomicDouble learningRate = new AtomicDouble(0.025);
    protected double minLearningRate = 0.01;
    protected int window = 5;
    protected int batchSize;
    protected int learningRateDecayWords;
    protected boolean resetModel;
    protected boolean useAdeGrad;
    protected int workers = 1;
    protected boolean trainSequenceVectors = false;
    protected boolean trainElementsVectors = true;
    protected long seed;
    protected boolean useUnknown = false;
    protected int[] variableWindows;
    public static final String DEFAULT_UNK = "UNK";
    private String UNK = "UNK";
    protected Collection<String> stopWords = new ArrayList<String>();

    public int getLayerSize() {
        if (this.lookupTable != null && this.lookupTable.getWeights() != null) {
            return this.lookupTable.getWeights().columns();
        }
        return this.layerSize;
    }

    @Override
    public boolean hasWord(String word) {
        return this.vocab().indexOf(word) >= 0;
    }

    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        return this.modelUtils.wordsNearestSum(positive, negative, top);
    }

    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        return this.modelUtils.wordsNearestSum(words, top);
    }

    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        return this.modelUtils.wordsNearest(words, top);
    }

    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        return this.modelUtils.wordsNearestSum(word, n);
    }

    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        return this.modelUtils.accuracy(questions);
    }

    @Override
    public int indexOf(String word) {
        return this.vocab().indexOf(word);
    }

    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        return this.modelUtils.similarWordsInVocabTo(word, accuracy);
    }

    @Override
    public double[] getWordVector(String word) {
        INDArray r = this.getWordVectorMatrix(word);
        if (r == null) {
            return null;
        }
        return r.dup().data().asDouble();
    }

    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        INDArray r = this.getWordVectorMatrix(word);
        if (r == null) {
            return null;
        }
        return r.div((Number)Nd4j.getBlasWrapper().nrm2(r));
    }

    @Override
    public INDArray getWordVectorMatrix(String word) {
        return this.lookupTable().vector(word);
    }

    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        return this.modelUtils.wordsNearest(positive, negative, top);
    }

    @Override
    public INDArray getWordVectors(@NonNull Collection<String> labels) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        int[] indexes = new int[labels.size()];
        int cnt = 0;
        boolean useIndexUnknown = this.useUnknown && this.vocab.containsWord(this.getUNK());
        for (String label : labels) {
            indexes[cnt] = this.vocab.containsWord(label) ? this.vocab.indexOf(label) : (useIndexUnknown ? this.vocab.indexOf(this.getUNK()) : -1);
            ++cnt;
        }
        while (ArrayUtils.contains((int[])indexes, (int)-1)) {
            indexes = ArrayUtils.removeElement((int[])indexes, (int)-1);
        }
        if (indexes.length == 0) {
            return Nd4j.empty((DataType)((InMemoryLookupTable)this.lookupTable).getSyn0().dataType());
        }
        INDArray result = Nd4j.pullRows((INDArray)this.lookupTable.getWeights(), (int)1, (int[])indexes);
        return result;
    }

    @Override
    public INDArray getWordVectorsMean(Collection<String> labels) {
        INDArray array = this.getWordVectors(labels);
        return array.mean(new int[]{0});
    }

    @Override
    public Collection<String> wordsNearest(String word, int n) {
        return this.modelUtils.wordsNearest(word, n);
    }

    @Override
    public double similarity(String word, String word2) {
        return this.modelUtils.similarity(word, word2);
    }

    @Override
    public VocabCache<T> vocab() {
        return this.vocab;
    }

    @Override
    public WeightLookupTable lookupTable() {
        return this.lookupTable;
    }

    @Override
    public void setModelUtils(@NonNull ModelUtils modelUtils) {
        if (modelUtils == null) {
            throw new NullPointerException("modelUtils is marked @NonNull but is null");
        }
        if (this.lookupTable != null) {
            modelUtils.init(this.lookupTable);
            this.modelUtils = modelUtils;
        }
    }

    public void setLookupTable(@NonNull WeightLookupTable lookupTable) {
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        this.lookupTable = lookupTable;
        if (this.modelUtils == null) {
            this.modelUtils = new BasicModelUtils();
        }
        this.modelUtils.init(lookupTable);
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    protected void update() {
        this.update(EnvironmentUtils.buildEnvironment(), Event.STANDALONE);
    }

    protected void update(Environment env, Event event) {
        if (!this.initDone) {
            this.initDone = true;
            Heartbeat heartbeat = Heartbeat.getInstance();
            Task task = new Task();
            task.setNumFeatures(this.layerSize);
            if (this.vocab != null) {
                task.setNumSamples(this.vocab.numWords());
            }
            task.setNetworkType(Task.NetworkType.DenseNetwork);
            task.setArchitectureType(Task.ArchitectureType.WORDVECTORS);
            heartbeat.reportEvent(event, env, task);
        }
    }

    public void loadWeightsInto(INDArray array) {
        array.assign(this.lookupTable.getWeights());
    }

    public long vocabSize() {
        return this.lookupTable.getWeights().size(0);
    }

    public int vectorSize() {
        return this.lookupTable.layerSize();
    }

    public boolean jsonSerializable() {
        return false;
    }

    @Override
    public boolean outOfVocabularySupported() {
        return false;
    }

    public int getMinWordFrequency() {
        return this.minWordFrequency;
    }

    public WeightLookupTable<T> getLookupTable() {
        return this.lookupTable;
    }

    public VocabCache<T> getVocab() {
        return this.vocab;
    }

    public ModelUtils<T> getModelUtils() {
        return this.modelUtils;
    }

    public int getWindow() {
        return this.window;
    }

    @Override
    public String getUNK() {
        return this.UNK;
    }

    @Override
    public void setUNK(String UNK) {
        this.UNK = UNK;
    }

    public Collection<String> getStopWords() {
        return this.stopWords;
    }
}

