/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.DenseSubspaceFunction;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class FilterSubspaces<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argument;
    private final DenseSubspaceFunction<NAMETYPE> function;

    private FilterSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) {
        this.argument = argument;
        this.function = function;
    }

    public FilterSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) {
        this(argument, new DenseSubspaceFunction<NAMETYPE>(functionArg, function));
        Objects.requireNonNull(argument, "The argument cannot be null");
        Objects.requireNonNull(functionArg, "The functionArg cannot be null");
        Objects.requireNonNull(function, "The function cannot be null");
    }

    private TensorType outputType(TensorType inputType) {
        TensorType m = inputType.mappedSubtype();
        TensorType i = inputType.indexedSubtype();
        TensorType d = this.function.outputType(i);
        if (m.rank() < 1) {
            throw new IllegalArgumentException("filter_subspaces needs input with at least 1 mapped dimension, but got: " + String.valueOf(inputType));
        }
        return inputType;
    }

    public TensorFunction<NAMETYPE> argument() {
        return this.argument;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.argument);
    }

    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("FilterSubspaces must have 1 argument, got " + arguments.size());
        }
        return new FilterSubspaces<NAMETYPE>(arguments.get(0), this.function);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new FilterSubspaces<NAMETYPE>(this.argument.toPrimitive(), this.function.toPrimitive());
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return this.outputType(this.argument.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor input = this.argument().evaluate(context);
        AddressSplitter splitter = new AddressSplitter(input.type());
        HashMap<TensorAddress, Tensor.Builder> builders = new HashMap<TensorAddress, Tensor.Builder>();
        Iterator<Tensor.Cell> iter = input.cellIterator();
        while (iter.hasNext()) {
            Tensor.Cell cell = iter.next();
            SplitAddr split = splitter.split(cell.getKey());
            Tensor.Builder builder = builders.computeIfAbsent(split.sparsePart(), k -> Tensor.Builder.of(splitter.denseType));
            builder.cell(split.densePart(), (double)cell.getValue());
        }
        Tensor.Builder builder = Tensor.Builder.of(splitter.fullType);
        for (Map.Entry entry : builders.entrySet()) {
            TensorAddress mappedAddr = (TensorAddress)entry.getKey();
            Tensor denseInput = ((Tensor.Builder)entry.getValue()).build();
            Tensor filterResult = this.function.map(denseInput).sum();
            if (filterResult.asDouble() == 0.0) continue;
            Iterator<Tensor.Cell> iter2 = denseInput.cellIterator();
            while (iter2.hasNext()) {
                Tensor.Cell cell = iter2.next();
                TensorAddress denseAddr = cell.getKey();
                TensorAddress fullAddr = splitter.combine(mappedAddr, denseAddr);
                builder.cell(fullAddr, (double)cell.getValue());
            }
        }
        return builder.build();
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "filter_subspaces(" + this.argument.toString(context) + ", " + String.valueOf(this.function) + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("filter_subspaces", this.argument, this.function);
    }

    static class AddressSplitter {
        final TensorType fullType;
        final TensorType sparseType;
        final TensorType denseType;

        AddressSplitter(TensorType fullType) {
            this.fullType = fullType;
            this.sparseType = fullType.mappedSubtype();
            this.denseType = fullType.indexedSubtype();
        }

        SplitAddr split(TensorAddress fullAddr) {
            TensorAddress.Builder mapAddrBuilder = new TensorAddress.Builder(this.sparseType);
            TensorAddress.Builder idxAddrBuilder = new TensorAddress.Builder(this.denseType);
            for (int i = 0; i < this.fullType.dimensions().size(); ++i) {
                TensorType.Dimension dim = this.fullType.dimensions().get(i);
                if (dim.isMapped()) {
                    mapAddrBuilder.add(dim.name(), fullAddr.objectLabel(i));
                    continue;
                }
                idxAddrBuilder.add(dim.name(), fullAddr.objectLabel(i));
            }
            TensorAddress mapAddr = mapAddrBuilder.build();
            TensorAddress idxAddr = idxAddrBuilder.build();
            return new SplitAddr(mapAddr, idxAddr);
        }

        TensorAddress combine(TensorAddress sparsePart, TensorAddress densePart) {
            TensorAddress.Builder addrBuilder = new TensorAddress.Builder(this.fullType);
            List<TensorType.Dimension> sparseDims = this.sparseType.dimensions();
            for (int i = 0; i < sparseDims.size(); ++i) {
                TensorType.Dimension dim = sparseDims.get(i);
                addrBuilder.add(dim.name(), sparsePart.objectLabel(i));
            }
            List<TensorType.Dimension> denseDims = this.denseType.dimensions();
            for (int i = 0; i < denseDims.size(); ++i) {
                TensorType.Dimension dim = denseDims.get(i);
                addrBuilder.add(dim.name(), densePart.objectLabel(i));
            }
            return addrBuilder.build();
        }
    }

    record SplitAddr(TensorAddress sparsePart, TensorAddress densePart) {
    }
}

