package shap4j;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import shap4j.shap.ExplanationDataset;
import shap4j.shap.TreeEnsemble;
import shap4j.shap.TreeShap;

/**
 * Computes per-feature contribution values for a tree-ensemble model by delegating to the native
 * {@link TreeShap} routines.
 *
 * <p>Instances are built from the serialized model bytes accepted by
 * {@link TreeEnsemble#fromBytes}: directly, from a file, or from a classpath resource.
 */
public class TreeExplainer {
    /**
     * Feature-dependence selector passed to {@code TreeShap.dense_tree_shap}.
     * NOTE(review): assumed to match the native "tree path dependent" enum value; confirm against
     * the bound TreeShap header.
     */
    private static final int TREE_PATH_DEPENDENT_FEATURE = 1;

    /**
     * Model-transform selector passed to {@code TreeShap.dense_tree_shap}.
     * NOTE(review): assumed to match the native "identity" transform enum value; confirm.
     */
    private static final int IDENTITY_TRANSFORM = 0;

    /** Chunk size used when reading a model resource to end of stream. */
    private static final int READ_CHUNK_SIZE = 8192;

    /** The native model every explanation is computed against. */
    private final TreeEnsemble treeEnsemble;

    private TreeExplainer(TreeEnsemble treeEnsemble) {
        this.treeEnsemble = treeEnsemble;
    }

    /**
     * Creates an explainer from serialized model bytes.
     *
     * @param modelBytes bytes in the format accepted by {@link TreeEnsemble#fromBytes}
     */
    public TreeExplainer(byte[] modelBytes) {
        this.treeEnsemble = TreeEnsemble.fromBytes(modelBytes);
    }

    /**
     * Computes contribution values for every row of {@code explanationDataset}.
     *
     * <p>This method takes ownership of the dataset: it is closed before this method returns,
     * whether the computation succeeds or throws.
     *
     * @param explanationDataset the rows to explain; closed by this method
     * @param saabas if {@code true}, uses {@code TreeShap.dense_tree_saabas}; otherwise uses
     *     {@code TreeShap.dense_tree_shap} with {@link #TREE_PATH_DEPENDENT_FEATURE},
     *     {@link #IDENTITY_TRANSFORM} and a final {@code false} flag
     * @return a {@code getNumRows() x getNumCols()} array of contributions. The native routine
     *     writes one extra trailing value per row, which is not copied into the result.
     * @throws IllegalArgumentException if the model does not have exactly one output
     */
    public double[][] shapValues(ExplanationDataset explanationDataset, boolean saabas) {
        DoublePointer contributions = null;
        try {
            // The native buffer below is sized for a single output; a multi-output model would
            // make the native code write past its end.
            requireSingleOutput();

            int numRows = explanationDataset.getNumRows();
            int numFeatures = explanationDataset.getNumCols();
            // Each row of native output holds numFeatures values plus one trailing extra value.
            long rowStride = (long) numFeatures + 1;
            // long arithmetic: an int product overflows for large datasets.
            long totalValues = numRows * rowStride;

            contributions = new DoublePointer(totalValues);
            // The native routines presumably accumulate into the buffer, so start from zeros.
            BytePointer.memset(contributions, 0, totalValues * Double.BYTES);

            if (saabas) {
                TreeShap.dense_tree_saabas(contributions, this.treeEnsemble, explanationDataset);
            } else {
                TreeShap.dense_tree_shap(this.treeEnsemble, explanationDataset, contributions,
                        TREE_PATH_DEPENDENT_FEATURE, IDENTITY_TRANSFORM, false);
            }

            double[][] result = new double[numRows][numFeatures];
            for (int row = 0; row < numRows; row++) {
                long offset = row * rowStride;
                // Copy the first numFeatures values of this row, skipping the trailing extra one.
                contributions.position(offset).limit(offset + rowStride).asBuffer()
                        .get(result[row], 0, numFeatures);
            }
            return result;
        } finally {
            // Release native memory even when the computation fails.
            try {
                explanationDataset.close();
            } finally {
                if (contributions != null) {
                    contributions.close();
                }
            }
        }
    }

    /**
     * Computes tree-SHAP contribution values (not Saabas) for every row of the dataset.
     *
     * @param explanationDataset the rows to explain; closed by this method
     * @return one row of contributions per dataset row
     */
    public double[][] shapValues(ExplanationDataset explanationDataset) {
        return shapValues(explanationDataset, false);
    }

    /**
     * Computes contribution values for a dense matrix of rows.
     *
     * @param rows the rows to explain; must contain at least one row
     * @param saabas if {@code true}, uses the Saabas approximation instead of tree SHAP
     * @param datasetFlag forwarded unchanged to {@code ExplanationDataset.fromMatrix}
     *     (NOTE(review): its meaning is defined there; confirm)
     * @return one row of contributions per input row
     * @throws IllegalArgumentException if {@code rows} is empty or the model does not have
     *     exactly one output
     */
    public double[][] shapValues(double[][] rows, boolean saabas, boolean datasetFlag) {
        if (rows.length == 0) {
            throw new IllegalArgumentException("rows must contain at least one row");
        }
        // Check before building the dataset so no native data is allocated for a rejected model.
        requireSingleOutput();
        return shapValues(ExplanationDataset.fromMatrix(rows, datasetFlag), saabas);
    }

    /**
     * Computes tree-SHAP contribution values (not Saabas) for a dense matrix of rows.
     *
     * <p>Note: unlike {@link #shapValues(ExplanationDataset, boolean)}, the boolean here is the
     * dataset flag, not the Saabas switch.
     *
     * @param rows the rows to explain; must contain at least one row
     * @param datasetFlag forwarded unchanged to {@code ExplanationDataset.fromMatrix}
     * @return one row of contributions per input row
     */
    public double[][] shapValues(double[][] rows, boolean datasetFlag) {
        return shapValues(rows, false, datasetFlag);
    }

    /**
     * Computes contribution values for a single row.
     *
     * @param row the feature values of one instance
     * @param saabas if {@code true}, uses the Saabas approximation instead of tree SHAP
     * @param datasetFlag forwarded unchanged to {@code ExplanationDataset.fromMatrix}
     * @return the contributions for {@code row}
     */
    public double[] shapValues(double[] row, boolean saabas, boolean datasetFlag) {
        return shapValues(new double[][] {row}, saabas, datasetFlag)[0];
    }

    /**
     * Computes tree-SHAP contribution values (not Saabas) for a single row.
     *
     * @param row the feature values of one instance
     * @param datasetFlag forwarded unchanged to {@code ExplanationDataset.fromMatrix}
     * @return the contributions for {@code row}
     */
    public double[] shapValues(double[] row, boolean datasetFlag) {
        return shapValues(new double[][] {row}, false, datasetFlag)[0];
    }

    /**
     * Loads an explainer from a classpath resource resolved relative to this class.
     *
     * @param str the resource name, as accepted by {@link Class#getResourceAsStream(String)}
     * @return the explainer, or {@code null} if {@code str} is {@code null} or no such resource
     *     exists
     * @throws IOException if reading the resource fails
     */
    public static TreeExplainer fromResource(String str) throws IOException {
        // Preserve the historical contract: an unresolvable resource yields null, not an exception.
        if (str == null) {
            return null;
        }
        try (InputStream in = TreeExplainer.class.getResourceAsStream(str)) {
            if (in == null) {
                return null;
            }
            return new TreeExplainer(readFully(in));
        }
    }

    /**
     * Loads an explainer from a file on disk.
     *
     * @param str path of the serialized model file
     * @return the explainer
     * @throws IOException if the file cannot be read
     */
    public static TreeExplainer fromFile(String str) throws IOException {
        return new TreeExplainer(Files.readAllBytes(new File(str).toPath()));
    }

    /** Throws unless the wrapped model has exactly one output, the only layout supported here. */
    private void requireSingleOutput() {
        if (this.treeEnsemble.num_outputs() != 1) {
            throw new IllegalArgumentException("Currently only supporting models with num_outputs == 1");
        }
    }

    /**
     * Reads {@code in} to end of stream. Unlike sizing a buffer from {@code available()}, this
     * does not depend on the stream reporting its total length.
     */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] chunk = new byte[READ_CHUNK_SIZE];
        int read;
        while ((read = in.read(chunk)) != -1) {
            out.write(chunk, 0, read);
        }
        return out.toByteArray();
    }
}
