/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.BeamBatchTensorList;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.modality.nlp.generate.ContrastiveBatchTensorList;
import ai.djl.modality.nlp.generate.GreedyBatchTensorList;
import ai.djl.modality.nlp.generate.SearchConfig;
import ai.djl.modality.nlp.generate.StepGeneration;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Autoregressive text generation driver that repeatedly runs a causal language model through a
 * {@link Predictor} and decodes tokens with one of three strategies: greedy, beam, or contrastive
 * search.
 *
 * <p>Every model call receives an {@link NDList} of {@code (inputIds, positionIds, attentionMask)}
 * built by {@link #prepareInput}, followed (after the first step) by the past key/value cache
 * returned by the previous step.
 *
 * <p>Instances are not safe for concurrent use: each search call overwrites the {@code
 * positionOffset} field that {@link #prepareInput} later reads, with no synchronization.
 */
public class TextGenerator {
    private final String searchName;
    private final SearchConfig config;
    private final Predictor<NDList, CausalLMOutput> predictor;

    /**
     * Per-row position shift of shape (numBatch, 1), written by {@link #prepareAttentionMaskOffset}:
     * the number of leading pad tokens with prefix padding, and all zeros with suffix padding.
     */
    private NDArray positionOffset;

    /**
     * Creates a generator.
     *
     * @param predictor the causal language model predictor invoked once per decoding step
     * @param searchName the strategy used by {@link #forward}: {@code "greedy"}, {@code "beam"} or
     *     {@code "contrastive"}
     * @param searchConfig search hyper-parameters (max length, beam width, k, alpha, padding)
     */
    public TextGenerator(
            Predictor<NDList, CausalLMOutput> predictor, String searchName, SearchConfig searchConfig) {
        this.predictor = predictor;
        this.searchName = searchName;
        this.config = searchConfig;
    }

    /**
     * Generates tokens greedily, taking the token chosen by {@link StepGeneration#greedyStepGen} at
     * every step.
     *
     * <p>At least one token is always generated; generation stops once the returned sequence length
     * reaches {@code config.getMaxSeqLength()}.
     *
     * @param inputIds the prompt token ids, indexed as (batch, sequence)
     * @return the prompt followed by the generated token ids, concatenated along axis 1
     * @throws TranslateException if the predictor fails
     */
    public NDArray greedySearch(NDArray inputIds) throws TranslateException {
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        NDManager manager = inputIds.getManager();
        long numBatch = inputIds.getShape().get(0);
        GreedyBatchTensorList searchState =
                new GreedyBatchTensorList(inputIds, null, null, attentionMask);
        do {
            try (NDScope ignore = new NDScope()) {
                NDArray pastOutputIds = searchState.getPastOutputIds();
                NDArray nextInputIds = searchState.getNextInputIds();
                NDArray pastAttentionMask = searchState.getPastAttentionMask();
                NDList pastKeyValues = searchState.getPastKeyValues();

                // On the first step there is no history: the whole prompt is fed at position 0.
                long pastSeqLength =
                        pastOutputIds == null ? 0L : pastOutputIds.getShape().getLastDimension();
                NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1);
                if (pastKeyValues != null) {
                    modelInput.addAll(pastKeyValues);
                }
                CausalLMOutput modelOutput = predictor.predict(modelInput);
                NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits());

                // The tokens just fed to the model become part of the output history.
                if (pastOutputIds == null) {
                    pastOutputIds = nextInputIds;
                } else {
                    pastOutputIds = pastOutputIds.concat(nextInputIds, 1);
                }
                searchState.setPastOutputIds(pastOutputIds);

                nextInputIds = outputIds;
                searchState.setNextInputIds(nextInputIds);

                pastKeyValues = modelOutput.getPastKeyValuesList();
                searchState.setPastKeyValues(pastKeyValues);

                // One more attended position for the token that will be fed next step.
                pastAttentionMask =
                        pastAttentionMask.concat(
                                manager.ones(new Shape(numBatch, 1L), DataType.INT64), 1);
                searchState.setPastAttentionMask(pastAttentionMask);

                // Keep the new state alive past the scope; everything else created here is freed.
                // NOTE(review): the previous step's state arrays are replaced but never closed
                // here; presumably they stay attached to the input's manager until it is closed.
                NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds);
                NDScope.unregister(pastKeyValues);
            }
        } while (searchState.getPastOutputIds().getShape().get(1) + 1L
                < (long) config.getMaxSeqLength());
        return searchState.getPastOutputIds().concat(searchState.getNextInputIds(), 1);
    }

    /**
     * Generates tokens with beam search, keeping {@code config.getBeam()} hypotheses per batch row.
     *
     * <p>At least one decoding step is always run; generation stops once the returned sequence
     * length reaches {@code config.getMaxSeqLength()}.
     *
     * @param inputIds the prompt token ids, indexed as (batch, sequence)
     * @return every hypothesis (prompt plus generated tokens), reshaped to (numBatch * numBeam, -1)
     * @throws TranslateException if the predictor fails
     */
    public NDArray beamSearch(NDArray inputIds) throws TranslateException {
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        NDManager manager = inputIds.getManager();
        long numBeam = config.getBeam();
        long numBatch = inputIds.getShape().get(0);
        BeamBatchTensorList searchState = new BeamBatchTensorList();
        do {
            if (searchState.getPastAttentionMask() == null) {
                // First step: run the full prompt once and seed numBeam hypotheses per batch row
                // from the top-numBeam tokens of the last position.
                NDList modelInput = prepareInput(inputIds, attentionMask, 0L, 1);
                CausalLMOutput modelOutput = predictor.predict(modelInput);
                NDArray allProbs = modelOutput.getLogits().get(":, -1, :").softmax(1);
                NDList topK = allProbs.topK(Math.toIntExact(numBeam), -1, true, false);
                NDArray outputIds = topK.get(1).expandDims(2);
                NDArray lastProbs = topK.get(0).normalize(1.0, 1L);
                assert outputIds.getShape().getShape().length == 3 : "Wrong shape";
                assert lastProbs.getShape().getShape().length == 2 : "Wrong Shape";

                // Replicate mask, cache and prompt across a new beam axis (axis 1).
                attentionMask =
                        attentionMask
                                .concat(manager.ones(new Shape(numBatch, 1L), DataType.INT64), -1)
                                .expandDims(1)
                                .repeat(1, numBeam);
                Function<NDArray, NDArray> addBeamAxis =
                        ndarray -> ndarray.expandDims(1).repeat(1, numBeam);
                NDList pastKeyValues =
                        new NDList(
                                modelOutput.getPastKeyValuesList().stream()
                                        .map(addBeamAxis)
                                        .collect(Collectors.toList()));
                NDArray pastOutputIds = inputIds.expandDims(1).repeat(1, numBeam);
                searchState =
                        new BeamBatchTensorList(
                                outputIds, pastOutputIds, pastKeyValues, attentionMask, lastProbs);
            }
            try (NDScope ignore = new NDScope()) {
                // The cache is stored as (batch, beam, heads, seq, kvDim); the model expects the
                // batch and beam axes merged.
                NDList cachedKeyValues = searchState.getPastKeyValues();
                Shape kvShape = cachedKeyValues.get(0).getShape();
                long numHeads = kvShape.get(2);
                long kvDim = kvShape.getLastDimension();
                long pastSeqLength = searchState.getPastOutputIds().getShape().getLastDimension();

                NDList modelInput =
                        prepareInput(
                                searchState.getNextInputIds().reshape(numBatch * numBeam, 1L),
                                searchState.getPastAttentionMask().reshape(numBatch * numBeam, -1L),
                                pastSeqLength,
                                config.getBeam());
                Function<NDArray, NDArray> mergeBeamAxis =
                        ndarray -> ndarray.reshape(numBatch * numBeam, numHeads, pastSeqLength, kvDim);
                NDList pastKeyValues =
                        new NDList(
                                cachedKeyValues.stream()
                                        .map(mergeBeamAxis)
                                        .collect(Collectors.toList()));
                modelInput.addAll(pastKeyValues);

                CausalLMOutput modelOutput = predictor.predict(modelInput);
                NDList generatedOutput =
                        StepGeneration.beamStepGeneration(
                                searchState.getLastProbs(), modelOutput.getLogits(), numBatch, numBeam);
                searchState = updateSearchState(searchState, modelOutput, generatedOutput, manager);

                // Keep the new state alive past the scope; everything else created here is freed.
                NDScope.unregister(
                        searchState.getNextInputIds(),
                        searchState.getPastOutputIds(),
                        searchState.getPastAttentionMask(),
                        searchState.getLastProbs());
                NDScope.unregister(searchState.getPastKeyValues());
            }
        } while (searchState.getPastOutputIds().getShape().getLastDimension() + 1L
                < (long) config.getMaxSeqLength());
        return searchState
                .getPastOutputIds()
                .concat(searchState.getNextInputIds(), -1)
                .reshape(numBatch * numBeam, -1L);
    }

    /**
     * Generates tokens with contrastive search.
     *
     * <p>At each step the top {@code config.getK()} candidate tokens are run through the model. One
     * candidate per batch row is then selected by {@link StepGeneration#constrastiveStepGeneration}
     * using {@code config.getAlpha()}. At least one step is always run; generation stops once the
     * output length reaches {@code config.getMaxSeqLength()}.
     *
     * @param inputIds the prompt token ids, indexed as (batch, sequence)
     * @return the prompt followed by the generated token ids, concatenated along axis 1
     * @throws TranslateException if the predictor fails
     */
    public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException {
        NDManager manager = inputIds.getManager();
        NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
        int k = config.getK();
        ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList();
        do {
            if (searchState.getPastKeyValues() == null) {
                // First step: run the full prompt once to obtain the cache, hidden states and the
                // logits of the last position.
                NDList modelInput = prepareInput(inputIds, attentionMask, 0L, 1);
                CausalLMOutput output = predictor.predict(modelInput);
                NDArray lastLogits = output.getLogits().get(":, -1, :");
                searchState =
                        new ContrastiveBatchTensorList(
                                inputIds,
                                attentionMask,
                                output.getHiddenState(),
                                lastLogits,
                                output.getPastKeyValuesList(),
                                new long[0]);
            }
            try (NDScope ignore = new NDScope()) {
                NDArray topKIds = searchState.getLogits().topK(k, -1, true, false).get(1);
                // Every (row, candidate) pair becomes its own single-token model input.
                NDArray candidateInputIds = topKIds.flatten().reshape(-1L, 1L);
                assert candidateInputIds.getDataType() == DataType.INT64
                        : "inputIds datatype should be int64";
                assert candidateInputIds.getShape().getShape().length == 2 : "shape not right";

                // Replicate the cache and mask k times along axis 0 to match the candidates.
                NDList kCopyPastKeyValues =
                        new NDList(
                                searchState.getPastKeyValues().stream()
                                        .map(ndarray -> ndarray.repeat(0, k))
                                        .collect(Collectors.toList()));
                assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32
                        : "pastKeyValues datatype should be Float32";
                long numBatch = topKIds.getShape().get(0);
                NDArray kCopyPastAttentionMask =
                        searchState
                                .getPastAttentionMask()
                                .repeat(0, k)
                                .concat(
                                        manager.ones(
                                                new Shape(numBatch * (long) k, 1L), DataType.INT64),
                                        1);
                assert kCopyPastKeyValues.get(0).getShape().get(2) + 1L
                                == kCopyPastAttentionMask.getShape().getLastDimension()
                        : "attentionMask_seq = past_seq + new_input_seq";

                NDList candidateModelInput =
                        prepareInput(
                                candidateInputIds,
                                kCopyPastAttentionMask,
                                searchState.getPastOutputIds().getShape().getLastDimension(),
                                k);
                candidateModelInput.addAll(kCopyPastKeyValues);
                CausalLMOutput candidateOutput = predictor.predict(candidateModelInput);

                NDList generatedOutput =
                        StepGeneration.constrastiveStepGeneration(
                                topKIds,
                                searchState.getLogits(),
                                searchState.getPastHiddenStates(),
                                candidateOutput.getHiddenState(),
                                positionOffset,
                                config.getAlpha());
                searchState =
                        updateSearchState(searchState, candidateOutput, generatedOutput, manager);

                // Keep the new state alive past the scope; everything else created here is freed.
                NDScope.unregister(
                        searchState.getPastOutputIds(),
                        searchState.getPastAttentionMask(),
                        searchState.getLogits(),
                        searchState.getPastHiddenStates());
                NDScope.unregister(searchState.getPastKeyValues());
            }
        } while (searchState.getPastOutputIds().getShape().get(1) < (long) config.getMaxSeqLength());
        return searchState.getPastOutputIds();
    }

    /**
     * Builds the next beam-search state by reordering the history, cache and mask according to the
     * source beam chosen for each new hypothesis.
     *
     * @param searchState the current state; its cache is laid out as (batch, beam, heads, seq, kv)
     * @param modelOutput the model output for this step, with batch and beam axes merged
     * @param generatedOutput {@code [nextInputIds, newProbs, sourceBeamSelected]} from
     *     {@link StepGeneration#beamStepGeneration}
     * @param manager the manager used to allocate helper arrays
     * @return the state for the next step
     */
    private static BeamBatchTensorList updateSearchState(
            BeamBatchTensorList searchState,
            CausalLMOutput modelOutput,
            NDList generatedOutput,
            NDManager manager) {
        Shape kvShape = searchState.getPastKeyValues().get(0).getShape();
        long numHeads = kvShape.get(2);
        long kvDim = kvShape.getLastDimension();
        Shape outputShape = searchState.getPastOutputIds().getShape();
        long numBatch = outputShape.get(0);
        long numBeam = outputShape.get(1);
        long pastSeqLength = outputShape.getLastDimension();

        NDArray nextInputIds = generatedOutput.get(0);
        assert nextInputIds.getShape().getShape().length == 3 : "Wrong Shape";
        NDArray newProbs = generatedOutput.get(1);
        NDArray sourceBeamSelected = generatedOutput.get(2);

        // Advanced index [batchRow, sourceBeam] selecting, for each new hypothesis, the beam it
        // was extended from.
        NDIndex sourceBeamIndex =
                new NDIndex(
                        "{}, {}, ...",
                        manager.arange(0.0f, numBatch, 1.0f, DataType.INT64)
                                .expandDims(1)
                                .repeat(1, numBeam),
                        sourceBeamSelected);

        NDArray pastOutputIds =
                searchState
                        .getPastOutputIds()
                        .concat(searchState.getNextInputIds(), -1)
                        .get(sourceBeamIndex);
        Function<NDArray, NDArray> splitAndSelect =
                ndarray ->
                        ndarray.reshape(numBatch, numBeam, numHeads, pastSeqLength + 1L, kvDim)
                                .get(sourceBeamIndex);
        NDList pastKeyValues =
                new NDList(
                        modelOutput.getPastKeyValuesList().stream()
                                .map(splitAndSelect)
                                .collect(Collectors.toList()));
        NDArray pastAttentionMask =
                searchState
                        .getPastAttentionMask()
                        .concat(manager.ones(new Shape(numBatch, numBeam, 1L), DataType.INT64), -1)
                        .get(sourceBeamIndex);
        return new BeamBatchTensorList(
                nextInputIds, pastOutputIds, pastKeyValues, pastAttentionMask, newProbs);
    }

    /**
     * Builds the next contrastive-search state from the candidate selected for each batch row.
     *
     * @param searchState the current state; its cache is laid out as (batch, heads, seq, kvDim)
     * @param candidateOutput the model output for all k candidates, with rows = batch * k
     * @param generatedOutput {@code [outputIds, select]} from
     *     {@link StepGeneration#constrastiveStepGeneration}
     * @param manager the manager used to allocate helper arrays
     * @return the state for the next step
     */
    private static ContrastiveBatchTensorList updateSearchState(
            ContrastiveBatchTensorList searchState,
            CausalLMOutput candidateOutput,
            NDList generatedOutput,
            NDManager manager) {
        assert candidateOutput.getLogits().getShape().get(1) == 1L
                : "dimension check: here, outputLogits corresponds to inputSeq == 1";
        long numBatch = searchState.getLogits().getShape().get(0);
        long logitsDim = searchState.getLogits().getShape().get(1);
        long pastSeqLengthPriorUpdate = searchState.getPastOutputIds().getShape().get(1);
        Shape kvShape = searchState.getPastKeyValues().get(0).getShape();
        long numHeads = kvShape.get(1);
        long kvDim = kvShape.get(3);
        long hiddenDim = searchState.getPastHiddenStates().getShape().get(2);
        long k = candidateOutput.getLogits().getShape().get(0) / numBatch;

        // Advanced index [batchRow, selectedCandidate] picking one candidate per row.
        NDArray select = generatedOutput.get(1);
        NDIndex selectIndex =
                new NDIndex(
                        "{}, {}, ...",
                        manager.arange(0.0f, numBatch, 1.0f, DataType.INT64),
                        select.flatten());

        NDArray nextLogits =
                candidateOutput.getLogits().reshape(numBatch, k, logitsDim).get(selectIndex);
        Function<NDArray, NDArray> splitAndSelect =
                ndarray ->
                        ndarray.reshape(numBatch, k, numHeads, pastSeqLengthPriorUpdate + 1L, kvDim)
                                .get(selectIndex);
        NDList nextPastKeyValue =
                new NDList(
                        candidateOutput.getPastKeyValuesList().stream()
                                .map(splitAndSelect)
                                .collect(Collectors.toList()));

        NDArray newHiddenState = candidateOutput.getHiddenState();
        assert newHiddenState.getManager() == manager : "possible leaky memory";
        NDArray nextPastHiddenStates =
                searchState
                        .getPastHiddenStates()
                        .concat(newHiddenState.reshape(numBatch, k, 1L, hiddenDim).get(selectIndex), 1);

        NDArray outputIds = generatedOutput.get(0);
        NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1);
        NDArray nextPastAttentionMask =
                searchState
                        .getPastAttentionMask()
                        .concat(manager.ones(new Shape(numBatch, 1L), DataType.INT64), 1);
        return new ContrastiveBatchTensorList(
                nextOutputIds,
                nextPastAttentionMask,
                nextPastHiddenStates,
                nextLogits,
                nextPastKeyValue,
                new long[0]);
    }

    /**
     * Builds the initial attention mask for {@code inputIds}, zeroing out pad tokens, and records
     * each row's position offset in {@link #positionOffset}.
     *
     * <p>With suffix padding, each row is scanned up to its first pad token, and that token and
     * everything after it are masked. With prefix padding, the leading run of pad tokens is masked
     * and its length becomes the row's position offset. With suffix padding, offsets stay 0.
     *
     * @param inputIds the prompt token ids, indexed as (batch, sequence)
     * @param config supplies the padding side and pad token id
     * @return an INT64 mask of shape (batch, sequence): 1 for attended tokens, 0 for padding
     */
    private NDArray prepareAttentionMaskOffset(NDArray inputIds, SearchConfig config) {
        boolean suffixPadding = config.isSuffixPadding();
        NDManager manager = inputIds.getManager();
        int numBatch = Math.toIntExact(inputIds.getShape().get(0));
        int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));
        NDArray attentionMask =
                manager.ones(new Shape(1L, inputIds.getShape().getLastDimension()), DataType.INT64)
                        .reshape(1L, -1L)
                        .repeat(0, numBatch);
        long[][] offset = new long[numBatch][1];
        for (int i = 0; i < numBatch; ++i) {
            long[] sequence = inputIds.get("{},:", i).toLongArray();
            // Suffix padding: stop at the first pad token.
            // Prefix padding: stop at the first non-pad token.
            int idx = 0;
            while (idx < initSeqSize) {
                boolean isPad = sequence[idx] == config.getPadTokenId();
                if (isPad == suffixPadding) {
                    break;
                }
                ++idx;
            }
            if (suffixPadding) {
                // Padding occupies [idx, initSeqSize).
                attentionMask.set(new NDIndex("{},{}:{}", i, idx, initSeqSize), 0);
            } else {
                // Padding occupies [0, idx); shift positions so the first real token is at 0.
                attentionMask.set(new NDIndex("{},{}:{}", i, 0, idx), 0);
                offset[i][0] = idx;
            }
        }
        this.positionOffset = manager.create(offset);
        return attentionMask;
    }

    /**
     * Assembles one model input: {@code (inputIds, positionIds, attentionMask)}.
     *
     * <p>Position ids count up from {@code pastSeqLength}. Each row's {@link #positionOffset}
     * (repeated {@code repeat} times along axis 0) is subtracted, and the result is clamped at 0.
     *
     * @param inputIds the token ids fed this step, indexed as (rows, newSeq)
     * @param attentionMask the attention mask covering past and new tokens
     * @param pastSeqLength number of tokens already processed
     * @param repeat how many consecutive input rows share one original batch row's offset
     * @return the model input list
     */
    private NDList prepareInput(
            NDArray inputIds, NDArray attentionMask, long pastSeqLength, int repeat) {
        NDArray positionIds =
                inputIds.getManager()
                        .arange(
                                pastSeqLength,
                                pastSeqLength + inputIds.getShape().getLastDimension(),
                                1.0f,
                                DataType.INT64)
                        .expandDims(0)
                        .repeat(0, inputIds.getShape().get(0));
        NDArray positionIdsShifted = positionIds.subi(positionOffset.repeat(0, repeat));
        positionIds = positionIdsShifted.maximum(positionIdsShifted.zerosLike());
        return new NDList(inputIds, positionIds, attentionMask);
    }

    /**
     * Runs the search strategy named at construction time.
     *
     * @param inputIds the prompt token ids, indexed as (batch, sequence)
     * @return the generated sequences, as returned by the selected search method
     * @throws TranslateException if the predictor fails
     * @throws IllegalArgumentException if the search name is null or not one of {@code greedy},
     *     {@code beam}, {@code contrastive}
     */
    public NDArray forward(NDArray inputIds) throws TranslateException {
        // A String switch on null throws NullPointerException; guard it so every invalid name
        // yields the same IllegalArgumentException.
        if (searchName != null) {
            switch (searchName) {
                case "greedy":
                    return greedySearch(inputIds);
                case "beam":
                    return beamSearch(inputIds);
                case "contrastive":
                    return contrastiveSearch(inputIds);
                default:
                    break;
            }
        }
        throw new IllegalArgumentException(
                "searchName not correctly specified. Please choose among: {greedy, beam, contrastive}");
    }
}

