package org.opensearch.ml.common.connector.functions.postprocess;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/* loaded from: input_file:org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.class */
/**
 * Post-process function that converts a list of rerank results into "similarity" tensors.
 *
 * <p>Each input element is a map carrying an {@code index} entry and a {@code relevance_score}
 * entry. The output contains one {@link ModelTensor} per input element, ordered by
 * {@code index} rather than by the order in which the results were received.
 */
public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {
    private static final String INDEX_FIELD = "index";
    private static final String SCORE_FIELD = "relevance_score";

    /**
     * Checks that the raw input has the shape {@link #process} expects.
     *
     * <p>Every element is checked (not only the first one), so that a malformed entry in the
     * middle of the list is rejected here with a clear message instead of failing later inside
     * {@link #process}. An empty list is accepted.
     *
     * @param obj the raw post-process input
     * @throws IllegalArgumentException if {@code obj} is not a list, if any element is not a map,
     *     or if any element lacks the {@code index} or {@code relevance_score} key
     */
    @Override
    public void validate(Object obj) {
        if (!(obj instanceof List)) {
            throw new IllegalArgumentException("Post process function input is not a List.");
        }
        for (Object item : (List<?>) obj) {
            if (!(item instanceof Map)) {
                throw new IllegalArgumentException("Post process function input is not a List of Map.");
            }
            Map<?, ?> result = (Map<?, ?>) item;
            if (!result.containsKey(INDEX_FIELD) || !result.containsKey(SCORE_FIELD)) {
                throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
            }
        }
    }

    /**
     * Builds one "similarity" tensor per rerank result, placed at the position named by the
     * result's {@code index} value.
     *
     * <p>Positions that no result refers to (for example when two results share an index) yield
     * a tensor whose single data element is {@code null}.
     *
     * @param list the rerank results; each map must provide {@code index} and {@code relevance_score}
     * @param mLResultDataType not used by this implementation; tensors are always tagged
     *     {@link MLResultDataType#FLOAT32}
     * @return a new mutable list with {@code list.size()} tensors, empty when {@code list} is empty
     * @throws IllegalArgumentException if an index is not a number or falls outside
     *     {@code [0, list.size())}, or if a non-null score is not a number
     */
    @Override
    public List<ModelTensor> process(List<Map<String, Object>> list, MLResultDataType mLResultDataType) {
        int size = list.size();
        List<ModelTensor> tensors = new ArrayList<>(size);
        if (size == 0) {
            return tensors;
        }

        // Scatter the scores so that slot i holds the score of the result whose index is i.
        Double[] scores = new Double[size];
        for (Map<String, Object> result : list) {
            int position = toPosition(result.get(INDEX_FIELD), size);
            scores[position] = toScore(result.get(SCORE_FIELD));
        }

        for (Double score : scores) {
            tensors.add(ModelTensor.builder()
                .name("similarity")
                .shape(new long[]{1})
                .data(new Number[]{score})
                .dataType(MLResultDataType.FLOAT32)
                .build());
        }
        return tensors;
    }

    /** Converts a raw index value to an array position, accepting any {@link Number} subtype. */
    private static int toPosition(Object rawIndex, int size) {
        if (!(rawIndex instanceof Number)) {
            throw new IllegalArgumentException("The rerank result index should be a number, but was: " + rawIndex);
        }
        int position = ((Number) rawIndex).intValue();
        if (position < 0 || position >= size) {
            // NOTE(review): presumably indices can exceed the result count when the remote service
            // returns only a subset of the documents; confirm whether that case needs support.
            throw new IllegalArgumentException(
                "The rerank result index " + position + " is out of range for " + size + " results.");
        }
        return position;
    }

    /** Converts a raw score to a {@code Double}, accepting any {@link Number} subtype; {@code null} passes through. */
    private static Double toScore(Object rawScore) {
        if (rawScore == null) {
            return null;
        }
        if (!(rawScore instanceof Number)) {
            throw new IllegalArgumentException("The rerank result relevance_score should be a number, but was: " + rawScore);
        }
        return Double.valueOf(((Number) rawScore).doubleValue());
    }
}
