package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.Pair;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/SafeTensorSplitter.class */
/**
 * Command-line tool that splits a single-file safetensors model into several chunk files plus an
 * index file.
 *
 * <p>Usage: {@code SafeTensorSplitter <modelDir>}. The directory must contain
 * {@code SafeTensorIndex.SINGLE_MODEL_NAME} and must not already contain
 * {@code SafeTensorIndex.MODEL_INDEX_JSON}. Each tensor is assigned to a chunk by
 * {@link #getChunkFile}, re-saved into a temporary file for that chunk, and every chunk is then
 * written to {@code <modelDir>/<chunkName>} as an 8-byte little-endian header length, the JSON
 * header, and the tensor data. Finally a {@code weight_map} from tensor name to chunk file name is
 * written to {@code SafeTensorIndex.MODEL_INDEX_JSON}.
 */
public class SafeTensorSplitter {
    /** Chunk granularity in bytes (20 GiB); see {@link #getChunkFile}. */
    static long MAX_CHUNK_SIZE = 20L << 30;

    /**
     * Returns the name of the chunk file a tensor belongs to.
     *
     * <p>The first number is the tensor's end offset ({@code dataOffsets[1]}) divided, rounding
     * down, by {@link #MAX_CHUNK_SIZE}; the second number is {@code modelSize} divided the same way.
     *
     * @param tensorInfo metadata of the tensor in the original single-file model
     * @param modelSize size in bytes of the original single-file model
     * @return a name of the form {@code model-NNNNN-of-MMMMM.safetensor}
     */
    static String getChunkFile(TensorInfo tensorInfo, long modelSize) {
        long chunkIndex = Math.floorDiv(tensorInfo.dataOffsets[1], MAX_CHUNK_SIZE);
        long chunkCount = Math.floorDiv(modelSize, MAX_CHUNK_SIZE);
        return String.format("model-%05d-of-%05d.safetensor", chunkIndex, chunkCount);
    }

    /**
     * Splits the single-file model in {@code args[0]} into chunk files and writes the index.
     *
     * @param args {@code args[0]} is the model directory
     * @throws IllegalArgumentException if no directory is given, the path is not a directory, the
     *     directory already has an index file, or it has no single-file model
     * @throws RuntimeException wrapping any {@link IOException} raised while splitting
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            throw new IllegalArgumentException("Missing model name");
        }
        String modelDir = args[0];
        if (!new File(modelDir).isDirectory()) {
            throw new IllegalArgumentException("Not a directory");
        }
        if (Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile().exists()) {
            throw new IllegalArgumentException("Already split");
        }
        if (!Paths.get(modelDir, SafeTensorIndex.SINGLE_MODEL_NAME).toFile().exists()) {
            throw new IllegalArgumentException("Missing model file");
        }
        WeightLoader loader = SafeTensorSupport.loadWeights(new File(modelDir));
        // Loop-invariant: the original file size feeds the "-of-N" part of every chunk name.
        long modelSize = new File(modelDir, SafeTensorIndex.SINGLE_MODEL_NAME).length();

        // tensor name -> chunk file name, in iteration order (becomes "weight_map").
        Map<String, String> weightMap = new LinkedHashMap<>();
        // chunk file name -> temporary file holding that chunk's re-saved tensor data.
        Map<String, Pair<RandomAccessFile, FileChannel>> tempChunks = new HashMap<>();
        // chunk file name -> (tensor name -> TensorInfo returned by save), used as the chunk header.
        Map<String, Map<String, TensorInfo>> chunkHeaders = new LinkedHashMap<>();

        // The Closeable guarantees every temporary chunk file is closed, even if a step fails.
        try (Closeable closeTempChunks = () -> closeAll(tempChunks.values())) {
            for (Map.Entry<String, TensorInfo> entry : loader.tensorInfoMap().entrySet()) {
                String name = entry.getKey();
                String chunkFile = getChunkFile(entry.getValue(), modelSize);
                weightMap.put(name, chunkFile);

                FileChannel tempChannel =
                        tempChunks.computeIfAbsent(chunkFile, k -> createTempChunk()).right;
                TensorInfo written = loader.load(name).save(tempChannel);
                System.out.println("Wrote " + name + " to " + chunkFile + " at "
                        + written.dataOffsets[0] + " to " + written.dataOffsets[1]);

                chunkHeaders.computeIfAbsent(chunkFile, k -> new LinkedHashMap<>()).put(name, written);
            }
            for (Map.Entry<String, Pair<RandomAccessFile, FileChannel>> entry : tempChunks.entrySet()) {
                String chunkFile = entry.getKey();
                writeChunkFile(modelDir, chunkFile, chunkHeaders.get(chunkFile), entry.getValue().right);
            }
            writeIndex(modelDir, weightMap);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Creates a temporary file (deleted on JVM exit) for buffering one chunk's tensor data. */
    private static Pair<RandomAccessFile, FileChannel> createTempChunk() {
        try {
            File tempFile = File.createTempFile("jlama", "chunk");
            tempFile.deleteOnExit();
            RandomAccessFile raf = new RandomAccessFile(tempFile, "rw");
            return Pair.of(raf, raf.getChannel());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes one chunk file: 8-byte little-endian header length, JSON header, then all of
     * {@code data} copied from position 0.
     */
    private static void writeChunkFile(
            String modelDir, String chunkFile, Map<String, TensorInfo> header, FileChannel data)
            throws IOException {
        byte[] headerJson = JsonSupport.om.writeValueAsBytes(header);
        System.out.println("Writing " + chunkFile + " with " + header.size() + " tensors");
        byte[] headerLength = new byte[Long.BYTES];
        ByteBuffer.wrap(headerLength).order(ByteOrder.LITTLE_ENDIAN).putLong(headerJson.length);

        try (RandomAccessFile out = new RandomAccessFile(Paths.get(modelDir, chunkFile).toFile(), "rw")) {
            // "rw" does not truncate: drop any leftover file from an interrupted earlier run so
            // stale bytes cannot trail (or precede) the freshly written data.
            out.setLength(0);
            out.write(headerLength);
            out.write(headerJson);
            FileChannel outChannel = out.getChannel();
            System.out.println("Writing " + data.size() + " bytes of data from " + outChannel.position());
            transferFully(data, outChannel);
        }
    }

    /** Writes the index JSON with empty "metadata" and the tensor-to-chunk "weight_map". */
    private static void writeIndex(String modelDir, Map<String, String> weightMap) throws IOException {
        Map<String, Object> index =
                Map.of("metadata", new HashMap<String, Object>(), "weight_map", weightMap);
        byte[] json = JsonSupport.om.writeValueAsBytes(index);
        try (RandomAccessFile out =
                new RandomAccessFile(Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile(), "rw")) {
            out.write(json);
        }
    }

    /**
     * Copies all of {@code source} (from position 0) to {@code target}'s current position.
     *
     * <p>A single {@link FileChannel#transferTo} call may move fewer bytes than requested, so this
     * loops until the whole source has been copied.
     */
    private static void transferFully(FileChannel source, FileChannel target) throws IOException {
        long size = source.size();
        long position = 0;
        while (position < size) {
            long transferred = source.transferTo(position, size - position, target);
            if (transferred <= 0) {
                throw new IOException(
                        "Transfer stalled at " + position + " of " + size + " bytes");
            }
            position += transferred;
        }
    }

    /** Closes every temporary chunk file; the first failure is rethrown with later ones suppressed. */
    private static void closeAll(Iterable<Pair<RandomAccessFile, FileChannel>> files) throws IOException {
        IOException failure = null;
        for (Pair<RandomAccessFile, FileChannel> file : files) {
            try {
                file.left.close(); // also closes the associated channel
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}
