package ai.vespa.embedding;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

/* loaded from: input_file:ai/vespa/embedding/PoolingStrategy.class */
/**
 * Strategies for reducing a model output tensor to a single one-dimensional tensor
 * of the requested {@link TensorType}.
 *
 * <p>Each strategy fills one cell per index of the first dimension of the target type.
 * NOTE(review): the inputs are presumably the token embeddings (batch, token, feature)
 * and the attention mask of a transformer model - confirm against the callers.
 */
public enum PoolingStrategy {

    /**
     * Divides the sum of {@code tensor} over dimension {@code d1} by the sum over {@code d1}
     * of {@code tensor2} expanded with dimension {@code d0}, then copies the cells at
     * address {@code (0, i)} of that quotient into the result.
     */
    MEAN {
        @Override
        public Tensor toSentenceEmbedding(TensorType tensorType, Tensor tensor, Tensor tensor2) {
            Tensor.Builder builder = Tensor.Builder.of(tensorType);
            Tensor summedValues = tensor.sum("d1");
            Tensor summedWeights = tensor2.expand("d0").sum("d1");
            Tensor averaged = summedValues.join(summedWeights, (sum, weight) -> sum / weight);
            long size = embeddingSize(tensorType);
            for (int i = 0; i < size; i++) {
                // Explicit arrays keep the same int-label / long-label overloads as before.
                builder.cell(averaged.get(TensorAddress.of(new int[]{0, i})), new long[]{i});
            }
            return builder.build();
        }
    },

    /**
     * Copies the cells at address {@code (0, 0, i)} of {@code tensor} into the result;
     * {@code tensor2} is not used.
     */
    CLS {
        @Override
        public Tensor toSentenceEmbedding(TensorType tensorType, Tensor tensor, Tensor tensor2) {
            Tensor.Builder builder = Tensor.Builder.of(tensorType);
            long size = embeddingSize(tensorType);
            for (int i = 0; i < size; i++) {
                // Explicit arrays keep the same int-label / long-label overloads as before.
                builder.cell(tensor.get(TensorAddress.of(new int[]{0, 0, i})), new long[]{i});
            }
            return builder.build();
        }
    };

    /**
     * Reduces the given input to a tensor of the given type.
     *
     * @param tensorType the type of the tensor to build; its first dimension must have a size
     * @param tensor     the tensor whose cells are pooled
     * @param tensor2    the tensor used as divisor by {@link #MEAN}; ignored by {@link #CLS}
     * @return a tensor of {@code tensorType} with one cell per index of its first dimension
     */
    public abstract Tensor toSentenceEmbedding(TensorType tensorType, Tensor tensor, Tensor tensor2);

    /**
     * Returns the strategy with the given name, ignoring case.
     *
     * @param str the strategy name, {@code "mean"} or {@code "cls"} in any letter case
     * @return the matching strategy
     * @throws IllegalArgumentException if the name matches no strategy
     */
    public static PoolingStrategy fromString(String str) {
        return switch (str.toLowerCase()) {
            case "mean" -> MEAN;
            case "cls" -> CLS;
            default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(str));
        };
    }

    /** Returns the size of the first dimension of the given type; throws if that dimension has no size. */
    private static long embeddingSize(TensorType tensorType) {
        return tensorType.dimensions().get(0).size().orElseThrow();
    }
}
