package ai.djl.pytorch.jni;

import ai.djl.engine.EngineException;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Platform;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Locates, extracts or downloads, and loads the native libtorch libraries and the DJL PyTorch JNI
 * library ({@code djl_torch}).
 *
 * <p>Libtorch is resolved from, in order: a {@code PYTORCH_LIBRARY_PATH} override, a native
 * package bundled on the classpath, or a download from {@code https://publish.djl.ai/pytorch/}
 * into the engine cache directory.
 */
public final class LibUtils {

    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);

    private static final String NATIVE_LIB_NAME = System.mapLibraryName("torch");
    private static final String JNI_LIB_NAME = System.mapLibraryName("djl_torch");

    private static final Pattern VERSION_PATTERN =
            Pattern.compile("(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)(-SNAPSHOT)?(-\\d+)?");
    private static final Pattern LIB_PATTERN = Pattern.compile("(.*\\.(so(\\.\\d+)*|dll|dylib))");
    // Leading CUDA version digits of a flavor string, e.g. "118" in "cu118-precxx11".
    private static final Pattern CUDA_FLAVOR_PATTERN = Pattern.compile("cu(\\d+)");

    // Populated by loadLibrary(); read by getVersion() and getLibtorchPath().
    private static LibTorch libTorch;

    /**
     * Describes a resolved libtorch installation: its directory, PyTorch version, DJL API version,
     * flavor (e.g. {@code cpu}, {@code cu118-precxx11}) and platform classifier.
     */
    public static final class LibTorch {

        Path dir;
        String version;
        String apiVersion;
        String flavor;
        String classifier;

        /**
         * Describes a user-supplied libtorch directory. The version and flavor come from the {@code
         * PYTORCH_VERSION} / {@code PYTORCH_FLAVOR} overrides when set. Otherwise they are inferred
         * from the detected platform and, for the flavor, from whether a GPU is present.
         *
         * @param dir the directory containing the libtorch native library
         */
        LibTorch(Path dir) {
            Platform platform = Platform.detectPlatform("pytorch");
            this.dir = dir;
            this.apiVersion = platform.getApiVersion();
            this.classifier = platform.getClassifier();

            String overrideVersion = Utils.getEnvOrSystemProperty("PYTORCH_VERSION");
            this.version =
                    overrideVersion == null || overrideVersion.isEmpty()
                            ? platform.getVersion()
                            : overrideVersion;

            String overrideFlavor = Utils.getEnvOrSystemProperty("PYTORCH_FLAVOR");
            if (overrideFlavor != null && !overrideFlavor.isEmpty()) {
                this.flavor = overrideFlavor;
            } else if (CudaUtils.getGpuCount() > 0) {
                this.flavor = "cu" + CudaUtils.getCudaVersionString() + "-precxx11";
            } else if ("linux".equals(platform.getOsPrefix())) {
                this.flavor = "cpu-precxx11";
            } else {
                this.flavor = "cpu";
            }
        }

        /**
         * Describes a libtorch directory whose version and classifier come from {@code platform}.
         *
         * @param dir the directory containing the libtorch native library
         * @param platform the platform the libraries were resolved for
         * @param flavor the flavor of the libraries in {@code dir}
         */
        LibTorch(Path dir, Platform platform, String flavor) {
            this.dir = dir;
            this.version = platform.getVersion();
            this.apiVersion = platform.getApiVersion();
            this.classifier = platform.getClassifier();
            this.flavor = flavor;
        }
    }

    private LibUtils() {}

    /**
     * Loads libtorch and its dependencies, then the DJL JNI library.
     *
     * <p>On Android (detected via {@code java.vendor.url}), only {@code djl_torch} is loaded, via
     * {@link System#loadLibrary(String)}.
     */
    public static synchronized void loadLibrary() {
        if ("http://www.android.com/".equals(System.getProperty("java.vendor.url"))) {
            System.loadLibrary("djl_torch");
            return;
        }
        libTorch = getLibTorch();
        loadLibTorch(libTorch);
        Path jniLibrary = findJniLibrary(libTorch);
        loadNativeLibrary(jniLibrary.toAbsolutePath().toString());
    }

    /** Returns the override installation if configured, otherwise a bundled or downloaded one. */
    private static LibTorch getLibTorch() {
        LibTorch override = findOverrideLibrary();
        return override != null ? override : findNativeLibrary();
    }

    /**
     * Returns the loaded libtorch version with any {@code -SNAPSHOT} or {@code -<build>} suffix
     * removed. Requires {@link #loadLibrary()} to have run (except on Android).
     *
     * @return the libtorch version string
     */
    public static String getVersion() {
        Matcher matcher = VERSION_PATTERN.matcher(libTorch.version);
        return matcher.matches() ? matcher.group(1) : libTorch.version;
    }

    /**
     * Returns the directory libtorch was loaded from. Requires {@link #loadLibrary()} to have run.
     *
     * @return the libtorch directory
     */
    public static String getLibtorchPath() {
        return libTorch.dir.toString();
    }

    /**
     * Loads every native library in {@code lib.dir} in dependency-friendly order.
     *
     * <p>The order is: eagerly-loadable helper libraries (ranked by {@link #eagerLoadOrder}), then
     * the Windows cuDNN DLLs, then the core torch libraries in a fixed order.
     */
    private static void loadLibTorch(LibTorch lib) {
        Path libDir = lib.dir.toAbsolutePath();
        if (Files.exists(libDir.resolve("libstdc++.so.6"))) {
            String libstd = Utils.getEnvOrSystemProperty("LIBSTDCXX_LIBRARY_PATH");
            if (libstd != null) {
                try {
                    logger.info("Loading libstdc++.so.6 from: {}", libstd);
                    System.load(libstd);
                } catch (UnsatisfiedLinkError e) {
                    // Best effort: fall back to the system libstdc++.
                    logger.warn("Failed Loading libstdc++.so.6 from: {}", libstd);
                }
            }
        }

        String libExclusion = Utils.getEnvOrSystemProperty("PYTORCH_LIBRARY_EXCLUSION", "");
        Set<String> exclusion = new HashSet<>(Arrays.asList(libExclusion.split(",")));
        boolean isCuda = lib.flavor.contains("cu");
        // Core libraries loaded explicitly, in this order, after everything else.
        List<String> deferred =
                Arrays.asList(
                        System.mapLibraryName("fbgemm"),
                        System.mapLibraryName("caffe2_nvrtc"),
                        System.mapLibraryName("torch_cpu"),
                        System.mapLibraryName("c10_cuda"),
                        System.mapLibraryName("torch_cuda_cpp"),
                        System.mapLibraryName("torch_cuda_cu"),
                        System.mapLibraryName("torch_cuda"),
                        System.mapLibraryName("nvfuser_codegen"),
                        System.mapLibraryName("torch"));
        Set<String> loadLater = new HashSet<>(deferred);

        // try-with-resources: the walk stream holds open directory handles.
        try (Stream<Path> paths = Files.walk(libDir)) {
            Map<Path, Integer> rank = new ConcurrentHashMap<>();
            paths.filter(
                            path -> {
                                int order = eagerLoadOrder(path, exclusion, loadLater, isCuda);
                                if (order == 0) {
                                    return false;
                                }
                                rank.put(path, order);
                                return true;
                            })
                    // sorted() consumes the whole filtered stream before reading rank.
                    .sorted(Comparator.comparingInt(rank::get))
                    .map(Path::toString)
                    .forEach(LibUtils::loadNativeLibrary);

            loadWindowsCudnn(libDir);

            List<String> coreLibraries =
                    isCuda
                            ? deferred
                            : Arrays.asList(
                                    System.mapLibraryName("fbgemm"),
                                    System.mapLibraryName("torch_cpu"),
                                    System.mapLibraryName("torch"));
            for (String name : coreLibraries) {
                Path path = libDir.resolve(name);
                if (Files.exists(path)) {
                    loadNativeLibrary(path.toString());
                }
            }
        } catch (IOException e) {
            throw new EngineException("Folder not exist! " + libDir, e);
        }
    }

    /**
     * Computes the eager-load priority of a file found under the libtorch directory.
     *
     * @param path the candidate file
     * @param exclusion file names excluded via {@code PYTORCH_LIBRARY_EXCLUSION}
     * @param deferred core library names that are loaded explicitly later
     * @param isCuda whether the libtorch flavor is a CUDA flavor
     * @return {@code 1}, {@code 2} or {@code 3} (lower loads first), or {@code 0} to skip the file
     */
    private static int eagerLoadOrder(
            Path path, Set<String> exclusion, Set<String> deferred, boolean isCuda) {
        Path fileName = path.getFileName();
        if (fileName == null) {
            return 0;
        }
        String name = fileName.toString();
        if (!LIB_PATTERN.matcher(name).matches() || exclusion.contains(name)) {
            return 0;
        }
        // Skip CUDA runtime libraries when the flavor is not CUDA. The original used '&&' here,
        // which no single file name can satisfy, so this check never fired.
        if (!isCuda
                && (name.contains("nvrtc") || name.contains("cudart") || name.contains("nvTools"))) {
            return 0;
        }
        if (name.startsWith("libarm_compute-") || name.startsWith("libopenblasp")) {
            return 2;
        }
        if (name.startsWith("libarm_compute_")) {
            return 3;
        }
        if (deferred.contains(name)
                || !Files.isRegularFile(path)
                || name.endsWith(JNI_LIB_NAME)
                || name.contains("torch_")
                || name.contains("caffe2_")
                || name.startsWith("cudnn")) {
            return 0;
        }
        return 1;
    }

    /**
     * Loads the Windows cuDNN DLLs from {@code libDir} in a fixed order. It loads the split v8
     * libraries when {@code cudnn64_8.dll} exists, otherwise {@code cudnn64_7.dll} if present.
     * These names start with {@code cudnn}, so the eager walk in {@link #loadLibTorch} skips them.
     */
    private static void loadWindowsCudnn(Path libDir) {
        if (Files.exists(libDir.resolve("cudnn64_8.dll"))) {
            String[] cudnn8 = {
                "cudnn64_8.dll",
                "cudnn_ops_infer64_8.dll",
                "cudnn_ops_train64_8.dll",
                "cudnn_cnn_infer64_8.dll",
                "cudnn_cnn_train64_8.dll",
                "cudnn_adv_infer64_8.dll",
                "cudnn_adv_train64_8.dll"
            };
            for (String dll : cudnn8) {
                loadNativeLibrary(libDir.resolve(dll).toString());
            }
        } else if (Files.exists(libDir.resolve("cudnn64_7.dll"))) {
            loadNativeLibrary(libDir.resolve("cudnn64_7.dll").toString());
        }
    }

    /** Returns the installation named by {@code PYTORCH_LIBRARY_PATH}, or {@code null}. */
    private static LibTorch findOverrideLibrary() {
        String libPath = Utils.getEnvOrSystemProperty("PYTORCH_LIBRARY_PATH");
        if (libPath != null) {
            return findLibraryInPath(libPath);
        }
        return null;
    }

    /**
     * Searches a {@link File#pathSeparator}-separated list for the libtorch native library.
     *
     * <p>Each entry may be the library file itself or a directory containing it.
     *
     * @param libPath the search path
     * @return the first matching installation, or {@code null} if none matches
     */
    private static LibTorch findLibraryInPath(String libPath) {
        for (String entry : libPath.split(File.pathSeparator)) {
            File file = new File(entry);
            if (!file.exists()) {
                continue;
            }
            if (file.isFile() && NATIVE_LIB_NAME.equals(file.getName())) {
                return new LibTorch(file.getParentFile().toPath().toAbsolutePath());
            }
            File candidate = new File(entry, NATIVE_LIB_NAME);
            if (candidate.exists() && candidate.isFile()) {
                return new LibTorch(file.toPath().toAbsolutePath());
            }
        }
        return null;
    }

    /**
     * Locates the DJL JNI library for {@code lib}.
     *
     * <p>The search order is the libtorch directory, then the engine cache. If the library is not
     * found there, it is extracted from the classpath when the bundled {@code jni_version} matches,
     * and downloaded otherwise.
     *
     * @param lib the resolved libtorch installation
     * @return the path of the JNI library
     * @throws EngineException if the version is malformed or the library cannot be copied
     */
    private static Path findJniLibrary(LibTorch lib) {
        String classifier = lib.classifier;
        String version = lib.version;
        String djlVersion = lib.apiVersion;
        String flavor = lib.flavor;

        Path libDir = lib.dir.toAbsolutePath();
        Path path = libDir.resolve(djlVersion + '-' + JNI_LIB_NAME);
        if (Files.exists(path)) {
            return path;
        }
        path = libDir.resolve(JNI_LIB_NAME);
        if (Files.exists(path)) {
            return path;
        }

        // Always check the cache dir; it may differ from lib.dir.
        Path jniDir =
                Utils.getEngineCacheDir("pytorch").resolve(version + '-' + flavor + '-' + classifier);
        Path jniPath = jniDir.resolve(djlVersion + '-' + JNI_LIB_NAME);
        if (Files.exists(jniPath)) {
            return jniPath;
        }

        Matcher matcher = VERSION_PATTERN.matcher(version);
        if (!matcher.matches()) {
            throw new EngineException("Unexpected version: " + version);
        }
        String shortVersion = matcher.group(1);

        String jniVersion = readBundledJniVersion();
        if (jniVersion == null) {
            downloadJniLib(jniDir, jniPath, djlVersion, shortVersion, classifier, flavor);
            return jniPath;
        }
        if (!jniVersion.startsWith(shortVersion + '-' + djlVersion)) {
            logger.warn("Found mismatch PyTorch jni: {}", jniVersion);
            downloadJniLib(jniDir, jniPath, djlVersion, shortVersion, classifier, flavor);
            return jniPath;
        }

        String libPath = "jnilib/" + classifier + '/' + flavor + '/' + JNI_LIB_NAME;
        logger.info("Extracting {} to cache ...", libPath);
        Path tmp = null;
        try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) {
            Files.createDirectories(jniDir);
            tmp = Files.createTempFile(jniDir, "jni", "tmp");
            Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, jniPath);
            return jniPath;
        } catch (IOException e) {
            throw new EngineException("Cannot copy jni files", e);
        } finally {
            // Removes the temp file if the move did not happen (e.g. on failure).
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Reads {@code jni_version} from {@code jnilib/pytorch.properties} on the classpath.
     *
     * @return the bundled JNI version, or {@code null} if no properties file is bundled
     * @throws AssertionError if the file cannot be read or lacks {@code jni_version}
     */
    private static String readBundledJniVersion() {
        try {
            URL url = ClassLoaderUtils.getResource("jnilib/pytorch.properties");
            if (url == null) {
                return null;
            }
            Properties prop = new Properties();
            try (InputStream is = Utils.openUrl(url)) {
                prop.load(is);
            }
            String jniVersion = prop.getProperty("jni_version");
            if (jniVersion == null) {
                throw new AssertionError("No PyTorch jni version found.");
            }
            return jniVersion;
        } catch (IOException e) {
            throw new AssertionError("Failed to read PyTorch jni properties file.", e);
        }
    }

    /**
     * Resolves libtorch when no override path is configured.
     *
     * <p>A {@code PYTORCH_VERSION} that does not prefix the bundled version forces a download.
     * Otherwise a placeholder platform triggers a download and a bundled package is extracted.
     */
    private static LibTorch findNativeLibrary() {
        Platform platform = Platform.detectPlatform("pytorch");
        String overrideVersion = Utils.getEnvOrSystemProperty("PYTORCH_VERSION");
        if (overrideVersion != null
                && !overrideVersion.isEmpty()
                && !platform.getVersion().startsWith(overrideVersion)) {
            logger.warn("Override PyTorch version: {}.", overrideVersion);
            return downloadPyTorch(Platform.detectPlatform("pytorch", overrideVersion));
        }
        if (platform.isPlaceholder()) {
            return downloadPyTorch(platform);
        }
        return copyNativeLibraryFromClasspath(platform);
    }

    /**
     * Extracts the native libraries bundled on the classpath into the engine cache directory,
     * reusing a previous extraction when it already contains libtorch.
     *
     * @param platform the bundled platform description
     * @return the cached installation
     * @throws EngineException if extraction fails
     */
    private static LibTorch copyNativeLibraryFromClasspath(Platform platform) {
        logger.debug("Found bundled PyTorch package: {}.", platform);
        String version = platform.getVersion();
        String flavor = platform.getFlavor();
        // A package that bundles libstdc++.so.6 is treated as a precxx11 flavor.
        if (!flavor.endsWith("-precxx11")
                && Arrays.asList(platform.getLibraries()).contains("libstdc++.so.6")) {
            flavor = flavor + "-precxx11";
        }
        String classifier = platform.getClassifier();

        Path tmp = null;
        try {
            Path cacheDir = Utils.getEngineCacheDir("pytorch");
            logger.debug("Using cache dir: {}", cacheDir);
            Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
            if (Files.exists(dir.resolve(NATIVE_LIB_NAME))) {
                return new LibTorch(dir.toAbsolutePath(), platform, flavor);
            }
            // The directory lacks libtorch: discard it before re-extracting.
            Utils.deleteQuietly(dir);

            if (!VERSION_PATTERN.matcher(version).matches()) {
                throw new AssertionError("Unexpected version: " + version);
            }
            String resourcePrefix = "pytorch/" + flavor + '/' + classifier;
            Files.createDirectories(cacheDir);
            tmp = Files.createTempDirectory(cacheDir, "tmp");
            for (String file : platform.getLibraries()) {
                String resource = resourcePrefix + '/' + file;
                logger.info("Extracting {} to cache ...", resource);
                try (InputStream is = ClassLoaderUtils.getResourceAsStream(resource)) {
                    Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
                }
            }
            Utils.moveQuietly(tmp, dir);
            return new LibTorch(dir.toAbsolutePath(), platform, flavor);
        } catch (IOException e) {
            throw new EngineException("Failed to extract PyTorch native library", e);
        } finally {
            // Removes a partially populated temp directory on failure.
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Loads one native library.
     *
     * <p>It uses {@link System#load(String)} unless the {@code ai.djl.pytorch.native_helper} system
     * property names a helper, in which case loading is delegated to {@code
     * ClassLoaderUtils.nativeLoad}.
     */
    private static void loadNativeLibrary(String path) {
        logger.debug("Loading native library: {}", path);
        String nativeHelper = System.getProperty("ai.djl.pytorch.native_helper");
        if (nativeHelper != null && !nativeHelper.isEmpty()) {
            ClassLoaderUtils.nativeLoad(nativeHelper, path);
        } else {
            System.load(path);
        }
    }

    /**
     * Downloads libtorch for {@code platform} into the engine cache, reusing a cached copy when
     * present.
     *
     * <p>For CUDA flavors, the newest published CUDA build not newer than the requested one is
     * chosen. With a {@code PYTORCH_FLAVOR} override, an exact match is also accepted. If no CUDA
     * build matches, the CPU build is used.
     *
     * @param platform the platform to download for
     * @return the cached installation
     * @throws EngineException if downloading fails or nothing matches the platform
     */
    private static LibTorch downloadPyTorch(Platform platform) {
        String version = platform.getVersion();
        String classifier = platform.getClassifier();
        String precxx11;
        boolean override;
        String flavor = Utils.getEnvOrSystemProperty("PYTORCH_FLAVOR");
        if (flavor == null || flavor.isEmpty()) {
            String baseFlavor = platform.getFlavor();
            boolean usePrecxx11 =
                    System.getProperty("os.name").startsWith("Linux")
                            && (Boolean.parseBoolean(
                                            Utils.getEnvOrSystemProperty("PYTORCH_PRECXX11"))
                                    || "aarch64".equals(platform.getOsArch()));
            precxx11 = usePrecxx11 ? "-precxx11" : "";
            flavor = baseFlavor + precxx11;
            override = false;
        } else {
            logger.info("Uses override PYTORCH_FLAVOR: {}", flavor);
            precxx11 = flavor.endsWith("-precxx11") ? "-precxx11" : "";
            override = true;
        }

        Path cacheDir = Utils.getEngineCacheDir("pytorch");
        Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
        if (Files.exists(dir.resolve(NATIVE_LIB_NAME))) {
            logger.debug("Using cache dir: {}", dir);
            return new LibTorch(dir.toAbsolutePath(), platform, flavor);
        }

        Matcher matcher = VERSION_PATTERN.matcher(version);
        if (!matcher.matches()) {
            throw new AssertionError("Unexpected version: " + version);
        }
        String link = "https://publish.djl.ai/pytorch/" + matcher.group(1);
        Path indexFile = downloadIndexFile(cacheDir, version, link);

        Path tmp = null;
        try (InputStream is = Files.newInputStream(indexFile)) {
            Files.createDirectories(cacheDir);
            List<String> lines = Utils.readLines(is);
            if (flavor.startsWith("cu")) {
                flavor = resolveCudaFlavor(flavor, lines, precxx11, classifier, override);
                dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
                if (Files.exists(dir.resolve(NATIVE_LIB_NAME))) {
                    return new LibTorch(dir.toAbsolutePath(), platform, flavor);
                }
            }

            logger.debug("Using cache dir: {}", dir);
            tmp = Files.createTempDirectory(cacheDir, "tmp");
            String prefix = flavor + '/' + classifier + '/';
            boolean found = false;
            for (String line : lines) {
                if (!line.startsWith(prefix)) {
                    continue;
                }
                found = true;
                URL url = new URL(link + '/' + line);
                // Strip the directory and the trailing ".gz" to get the target file name.
                String fileName =
                        URLDecoder.decode(
                                line.substring(line.lastIndexOf('/') + 1, line.length() - 3),
                                "UTF-8");
                logger.info("Downloading {} ...", url);
                // Separate resources so the URL stream is closed even if the gzip header is bad.
                try (InputStream raw = Utils.openUrl(url);
                        InputStream gzip = new GZIPInputStream(raw)) {
                    Files.copy(gzip, tmp.resolve(fileName), StandardCopyOption.REPLACE_EXISTING);
                }
            }
            if (!found) {
                throw new EngineException(
                        "No PyTorch native library matches your operating system: " + platform);
            }
            Utils.moveQuietly(tmp, dir);
            return new LibTorch(dir.toAbsolutePath(), platform, flavor);
        } catch (IOException e) {
            throw new EngineException("Failed to download PyTorch native library", e);
        } finally {
            // Removes a partially populated temp directory on any failure.
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Ensures {@code <version>.txt}, the published file index, exists in {@code cacheDir}.
     *
     * @return the path of the cached index file
     * @throws EngineException if the index cannot be downloaded
     */
    private static Path downloadIndexFile(Path cacheDir, String version, String link) {
        Path indexFile = cacheDir.resolve(version + ".txt");
        if (Files.notExists(indexFile)) {
            Path tempFile = cacheDir.resolve(version + ".tmp");
            try (InputStream is = Utils.openUrl(link + "/files.txt")) {
                Files.createDirectories(cacheDir);
                Files.copy(is, tempFile, StandardCopyOption.REPLACE_EXISTING);
                Utils.moveQuietly(tempFile, indexFile);
            } catch (IOException e) {
                throw new EngineException("Failed to save pytorch index file", e);
            } finally {
                Utils.deleteQuietly(tempFile);
            }
        }
        return indexFile;
    }

    /**
     * Picks the CUDA flavor to download from the published index.
     *
     * @param flavor the requested flavor, starting with {@code cu} followed by version digits
     * @param lines the lines of the published index file
     * @param precxx11 {@code "-precxx11"} or {@code ""}
     * @param classifier the platform classifier
     * @param override whether {@code flavor} came from {@code PYTORCH_FLAVOR}
     * @return the chosen flavor, or the CPU flavor if no CUDA build fits
     * @throws EngineException if {@code flavor} has no CUDA version digits
     */
    private static String resolveCudaFlavor(
            String flavor, List<String> lines, String precxx11, String classifier, boolean override) {
        Matcher cudaMatcher = CUDA_FLAVOR_PATTERN.matcher(flavor);
        if (!cudaMatcher.lookingAt()) {
            throw new EngineException("Unexpected CUDA flavor: " + flavor);
        }
        int cudaVersion = Integer.parseInt(cudaMatcher.group(1));

        Pattern pattern =
                Pattern.compile(
                        "cu(\\d\\d\\d)"
                                + Pattern.quote(
                                        precxx11
                                                + '/'
                                                + classifier
                                                + "/native/lib/"
                                                + NATIVE_LIB_NAME
                                                + ".gz"));
        List<Integer> available = new ArrayList<>();
        for (String line : lines) {
            Matcher m = pattern.matcher(line);
            if (m.matches()) {
                available.add(Integer.parseInt(m.group(1)));
            }
        }
        available.sort(Collections.reverseOrder());
        for (int cuda : available) {
            if (override && cuda == cudaVersion) {
                return flavor;
            }
            if (cuda <= cudaVersion) {
                return "cu" + cuda + precxx11;
            }
        }
        logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
        return "cpu" + precxx11;
    }

    /**
     * Downloads the DJL JNI library into {@code dir} and moves it to {@code path}.
     *
     * @param dir the cache directory to download into
     * @param path the final location of the JNI library
     * @param djlVersion the DJL API version
     * @param version the PyTorch version without suffixes
     * @param classifier the platform classifier
     * @param flavor the libtorch flavor
     * @throws EngineException if the download fails
     */
    private static void downloadJniLib(
            Path dir,
            Path path,
            String djlVersion,
            String version,
            String classifier,
            String flavor) {
        String url =
                "https://publish.djl.ai/pytorch/"
                        + version
                        + "/jnilib/"
                        + djlVersion
                        + '/'
                        + classifier
                        + '/'
                        + flavor
                        + '/'
                        + JNI_LIB_NAME;
        logger.info("Downloading jni {} to cache ...", url);
        Path tmp = null;
        try (InputStream is = Utils.openUrl(url)) {
            Files.createDirectories(dir);
            tmp = Files.createTempFile(dir, "jni", "tmp");
            Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, path);
        } catch (IOException e) {
            throw new EngineException("Cannot download jni files: " + url, e);
        } finally {
            // Removes the temp file if the move did not happen (e.g. on failure).
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }
}
