package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

@Beta
/**
 * Reduces the argument tensor over some of its dimensions: all cells whose labels agree on the remaining
 * dimensions are collapsed into a single cell by an {@link Aggregator}. When no dimensions are given, or all
 * of the argument's dimensions are named, the whole tensor is collapsed into a tensor of type
 * {@code TensorType.empty}.
 */
public class Reduce extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final List<String> dimensions;
    private final Aggregator aggregator;

    /** The aggregation operations a {@link Reduce} can apply to the values it collapses. */
    public enum Aggregator {
        /** Arithmetic mean of the values. */
        avg,
        /** Number of values. */
        count,
        /** Product of the values. */
        prod,
        /** Sum of the values. */
        sum,
        /** Largest value. */
        max,
        /** Smallest value. */
        min
    }

    /** Computes the arithmetic mean; yields NaN when no values were aggregated (0.0 / 0). */
    public static class AvgAggregator extends ValueAggregator {

        private int valueCount = 0;
        private double valueSum = 0.0;

        private AvgAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueCount++;
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum / valueCount;
        }
    }

    /** Counts the aggregated values; the values themselves are ignored. */
    public static class CountAggregator extends ValueAggregator {

        private int valueCount = 0;

        private CountAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueCount++;
        }

        @Override
        public double aggregatedValue() {
            return valueCount;
        }
    }

    /**
     * Computes the largest value. Starts from negative infinity so that all-negative inputs are handled
     * correctly ({@code Double.MIN_VALUE} is the smallest <em>positive</em> double, not the most negative one).
     * Yields negative infinity when no values were aggregated. NaN inputs never replace the current maximum,
     * because every comparison against NaN is false.
     */
    public static class MaxAggregator extends ValueAggregator {

        private double maxValue = Double.NEGATIVE_INFINITY;

        private MaxAggregator() {
        }

        @Override
        public void aggregate(double value) {
            if (value > maxValue) {
                maxValue = value;
            }
        }

        @Override
        public double aggregatedValue() {
            return maxValue;
        }
    }

    /**
     * Computes the smallest value. Starts from positive infinity so that an input of {@code +Infinity} is
     * returned as such rather than as {@code Double.MAX_VALUE}. Yields positive infinity when no values were
     * aggregated. NaN inputs never replace the current minimum, because every comparison against NaN is false.
     */
    public static class MinAggregator extends ValueAggregator {

        private double minValue = Double.POSITIVE_INFINITY;

        private MinAggregator() {
        }

        @Override
        public void aggregate(double value) {
            if (value < minValue) {
                minValue = value;
            }
        }

        @Override
        public double aggregatedValue() {
            return minValue;
        }
    }

    /** Computes the product of the values; yields 1.0 when no values were aggregated. */
    public static class ProdAggregator extends ValueAggregator {

        private double valueProd = 1.0;

        private ProdAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueProd *= value;
        }

        @Override
        public double aggregatedValue() {
            return valueProd;
        }
    }

    /** Computes the sum of the values; yields 0.0 when no values were aggregated. */
    public static class SumAggregator extends ValueAggregator {

        private double valueSum = 0.0;

        private SumAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum;
        }
    }

    /** Accumulates a sequence of double values into a single aggregated result. */
    public static abstract class ValueAggregator {

        private ValueAggregator() {
        }

        /**
         * Returns a new, empty aggregator implementing the given operation.
         *
         * @param aggregator the operation to implement
         * @return a fresh aggregator with no values added
         * @throws UnsupportedOperationException if the operation has no implementation
         */
        public static ValueAggregator ofType(Aggregator aggregator) {
            switch (aggregator) {
                case avg:
                    return new AvgAggregator();
                case count:
                    return new CountAggregator();
                case prod:
                    return new ProdAggregator();
                case sum:
                    return new SumAggregator();
                case max:
                    return new MaxAggregator();
                case min:
                    return new MinAggregator();
                default:
                    throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
            }
        }

        /** Adds a value to this aggregation. */
        public abstract void aggregate(double value);

        /** Returns the result of aggregating all values added so far. */
        public abstract double aggregatedValue();
    }

    /** Creates a reduce over all dimensions of the argument. */
    public Reduce(TensorFunction argument, Aggregator aggregator) {
        this(argument, aggregator, Collections.<String>emptyList());
    }

    /** Creates a reduce over a single named dimension of the argument. */
    public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, Collections.singletonList(dimension));
    }

    /**
     * Creates a reduce over the named dimensions of the argument; an empty list means all dimensions.
     *
     * @throws NullPointerException if any argument is null
     */
    public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        Objects.requireNonNull(dimensions, "The dimensions cannot be null");
        this.argument = argument;
        this.aggregator = aggregator;
        this.dimensions = ImmutableList.copyOf(dimensions);
    }

    /** Returns the function producing the tensor to reduce. */
    public TensorFunction argument() {
        return argument;
    }

    @Override
    public List<TensorFunction> functionArguments() {
        return Collections.singletonList(argument);
    }

    /**
     * Returns a copy of this reduce with the argument replaced.
     *
     * @throws IllegalArgumentException unless exactly one argument is given
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
        }
        return new Reduce(arguments.get(0), aggregator, dimensions);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Reduce(argument.toPrimitive(), aggregator, dimensions);
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
    }

    /** Renders each element preceded by ", ", so an empty list renders as an empty string. */
    private String commaSeparated(List<String> list) {
        StringBuilder b = new StringBuilder();
        for (String element : list) {
            b.append(", ").append(element);
        }
        return b.toString();
    }

    /**
     * Evaluates the argument and reduces it.
     *
     * @throws IllegalArgumentException if a named dimension is not present in the evaluated argument
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor input = argument.evaluate(context);
        TensorType inputType = input.type();
        if (!dimensions.isEmpty() && !inputType.dimensionNames().containsAll(dimensions)) {
            throw new IllegalArgumentException("Cannot reduce " + input + " over dimensions " + dimensions +
                                               ": Not all those dimensions are present in this tensor");
        }

        // Count distinct names: a dimension listed twice must not make a partial reduce look like a full one
        Set<String> reducedNames = new HashSet<>(dimensions);
        if (reducedNames.isEmpty() || reducedNames.size() == inputType.dimensions().size()) {
            return reduceAll(input);
        }
        return reducePartially(input, reducedNames);
    }

    /** Collapses every cell of the tensor into one value, taking a fast path for 1-d indexed tensors. */
    private Tensor reduceAll(Tensor tensor) {
        if (tensor.type().dimensions().size() == 1 && tensor instanceof IndexedTensor) {
            return reduceIndexedVector((IndexedTensor) tensor);
        }
        return reduceAllGeneral(tensor);
    }

    /**
     * Reduces only the given dimensions: cells whose labels agree on every other dimension are aggregated
     * into one cell of the result, whose type keeps the non-reduced dimensions in their original order.
     * All names in {@code reducedNames} must be dimensions of the tensor (checked by {@link #evaluate}).
     */
    private Tensor reducePartially(Tensor tensor, Set<String> reducedNames) {
        TensorType inputType = tensor.type();
        TensorType.Builder typeBuilder = new TensorType.Builder();
        for (TensorType.Dimension dimension : inputType.dimensions()) {
            if (!reducedNames.contains(dimension.name())) {
                typeBuilder.dimension(dimension);
            }
        }
        TensorType reducedType = typeBuilder.build();

        // Resolved once rather than once per cell: the positions do not depend on the cell
        Set<Integer> reducedIndexes = new HashSet<>();
        for (String name : reducedNames) {
            reducedIndexes.add(inputType.indexOfDimension(name).get()); // present: validated in evaluate
        }
        int reducedRank = reducedType.dimensions().size();

        Map<TensorAddress, ValueAggregator> aggregators = new HashMap<>();
        for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
            Tensor.Cell cell = cells.next();
            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedIndexes, reducedRank);
            aggregators.computeIfAbsent(reducedAddress, address -> ValueAggregator.ofType(aggregator))
                       .aggregate(cell.getValue().doubleValue());
        }

        Tensor.Builder builder = Tensor.Builder.of(reducedType);
        for (Map.Entry<TensorAddress, ValueAggregator> entry : aggregators.entrySet()) {
            builder.cell(entry.getKey(), entry.getValue().aggregatedValue());
        }
        return builder.build();
    }

    /**
     * Returns an address holding the labels of {@code address} except those at {@code reducedIndexes},
     * keeping the remaining labels in their original order.
     */
    private TensorAddress reduceDimensions(TensorAddress address, Set<Integer> reducedIndexes, int reducedRank) {
        String[] labels = new String[reducedRank];
        int next = 0;
        for (int i = 0; i < address.size(); i++) {
            if (!reducedIndexes.contains(i)) {
                labels[next++] = address.label(i);
            }
        }
        return TensorAddress.of(labels);
    }

    /** Aggregates every value of the tensor into a single cell of a tensor of type {@code TensorType.empty}. */
    private Tensor reduceAllGeneral(Tensor tensor) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (Iterator<Double> values = tensor.valueIterator(); values.hasNext(); ) {
            valueAggregator.aggregate(values.next().doubleValue());
        }
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue(), new int[0]).build();
    }

    /** Aggregates a one-dimensional indexed tensor by position, avoiding the generic value iterator. */
    private Tensor reduceIndexedVector(IndexedTensor tensor) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (int i = 0; i < tensor.dimensionSizes().size(0); i++) {
            valueAggregator.aggregate(tensor.get(i));
        }
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue(), new int[0]).build();
    }
}
