/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

import java.util.Locale;
import java.util.function.IntToDoubleFunction;

/*
 * Uses 'sealed' constructs - enable with --sealed true
 */
/**
 * Strategies for reducing the token-level output of an embedding model to a single vector of the
 * requested {@link TensorType}.
 *
 * <p>NOTE(review): the indexing below suggests {@code tokenEmbeddings} is laid out as
 * (d0 = batch, d1 = token, d2 = hidden) for MEAN and CLS, and as (d0 = batch, d1 = hidden) for NONE;
 * only batch index 0 is ever read. Confirm against the model outputs wired in by callers.
 */
enum PoolingStrategy {

    /**
     * Sums the token embeddings along {@code d1}, divides element-wise by the attention mask summed
     * along {@code d1} (after {@code expand("d0")}), and reads the result at batch index 0.
     */
    MEAN {
        @Override
        public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) {
            Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
            Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
            // NOTE(review): the embeddings are summed without being multiplied by the mask, so this is a
            // true masked mean only when every mask entry is 1 - confirm callers never pass padded input.
            // A mask summing to 0 makes the division yield NaN/Infinity.
            Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
            return toVector(type, i -> averaged.get(TensorAddress.of(0, i)));
        }
    },

    /** Takes the embedding at address (0, 0, i), i.e. the first entry along d1 for batch index 0. */
    CLS {
        @Override
        public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
            return toVector(type, i -> tokenEmbeddings.get(TensorAddress.of(0, 0, i)));
        }
    },

    /**
     * Copies the model output at address (0, i) unchanged; presumably for models whose output is
     * already pooled to one vector per input.
     */
    NONE {
        @Override
        public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
            return toVector(type, i -> tokenEmbeddings.get(TensorAddress.of(0, i)));
        }
    };

    /**
     * Reduces token embeddings to a single embedding tensor.
     *
     * @param type            the type of the returned tensor; its first dimension must have a bound size
     * @param tokenEmbeddings the model output to pool
     * @param attentionMask   the attention mask (read only by {@link #MEAN})
     * @return a tensor of {@code type} whose first dimension is filled from index 0 up to its size
     * @throws IllegalArgumentException if {@code type} has no dimensions or its first dimension is unbound
     */
    abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask);

    /**
     * Builds a tensor of {@code type} whose cell at label {@code i} of the first dimension is
     * {@code valueAt.applyAsDouble(i)}.
     */
    private static Tensor toVector(TensorType type, IntToDoubleFunction valueAt) {
        long size = firstDimensionSize(type); // hoisted: previously unwrapped and unboxed every iteration
        Tensor.Builder builder = Tensor.Builder.of(type);
        for (int i = 0; i < size; i++) {
            builder.cell(valueAt.applyAsDouble(i), i);
        }
        return builder.build();
    }

    /** Returns the bound size of the first dimension of {@code type}, failing with a clear message. */
    private static long firstDimensionSize(TensorType type) {
        if (type.dimensions().isEmpty()) {
            throw new IllegalArgumentException("Expected a tensor type with at least one dimension, but got " + type);
        }
        return type.dimensions().get(0).size().orElseThrow(() -> new IllegalArgumentException(
                "Expected a tensor type whose first dimension has a bound size, but got " + type));
    }

    /**
     * Parses a pooling strategy name case-insensitively.
     *
     * @param strategy one of {@code "mean"}, {@code "none"} or {@code "cls"}, in any letter case
     * @return the matching strategy
     * @throws IllegalArgumentException if the name is not recognized
     */
    static PoolingStrategy fromString(String strategy) {
        // Locale.ROOT keeps keyword matching independent of the JVM's default locale.
        return switch (strategy.toLowerCase(Locale.ROOT)) {
            case "mean" -> MEAN;
            case "none" -> NONE;
            case "cls" -> CLS;
            default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy));
        };
    }
}

