package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.SegmentedTensor;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.io.UncheckedIOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/SafeTensorIndex.class */
/**
 * {@link WeightLoader} backed by one or more safetensors files that are memory-mapped read-only.
 *
 * <p>Each file is mapped in "splits": contiguous byte ranges small enough for a single
 * {@link FileChannel#map} call (at most {@link Integer#MAX_VALUE} bytes). A tensor that is too large
 * for one split on its own is stored as chunks named {@code "<name>-part-<n>"}, which
 * {@link #load} reassembles with {@link SegmentedTensor#wrap}.
 *
 * <p>Instances are populated by the static factory methods and use plain {@link HashMap}s; no
 * synchronization is performed.
 */
public class SafeTensorIndex implements WeightLoader, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorIndex.class);
    private static final ObjectMapper om = new ObjectMapper();

    public static final String SINGLE_MODEL_NAME = "model.safetensors";
    public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";

    /** Maximum number of leading bytes mapped when reading a file's tensor-info table. */
    private static final long HEADER_MAP_LIMIT = 1L << 20;

    private final Map<String, String> metadata;

    /**
     * Values are file names resolved against the model directory. Keys are not interpreted by this
     * class; for the index JSON they come from its {@code weight_map} object.
     */
    final Map<String, String> weightFileMap;

    /** Tensor name -> info, aggregated across every loaded file (offsets are rebased in place). */
    final Map<String, TensorInfo> allTensorInfoMap = new HashMap<>();

    /** Tensor or chunk name -> the {@link Weights} view over the mapped split that contains it. */
    private final Map<String, Weights> weightMap = new HashMap<>();

    /** File name -> open handle, kept so {@link #close()} can release it. */
    private final Map<String, RandomAccessFile> fileMap = new HashMap<>();

    /**
     * Reads {@value #MODEL_INDEX_JSON} from {@code modelRoot} and memory-maps every file it names.
     *
     * @param modelRoot directory containing the index JSON and the files it references
     * @return the populated index; the caller should {@link #close()} it when done
     * @throws UncheckedIOException if the index cannot be read or a file cannot be opened/mapped
     */
    public static SafeTensorIndex loadWithWeights(Path modelRoot) {
        SafeTensorIndex index;
        try {
            index = om.readValue(Paths.get(modelRoot.toString(), MODEL_INDEX_JSON).toFile(), SafeTensorIndex.class);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return loadWeightsOrClose(index, modelRoot);
    }

    /**
     * Memory-maps a single safetensors file.
     *
     * @param modelRoot directory containing the file
     * @param fileName  file name relative to {@code modelRoot}
     * @return the populated index; the caller should {@link #close()} it when done
     * @throws UncheckedIOException if the file cannot be opened or mapped
     */
    public static SafeTensorIndex loadSingleFile(Path modelRoot, String fileName) {
        SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", fileName));
        return loadWeightsOrClose(index, modelRoot);
    }

    /**
     * Runs {@link #loadWeights}. If it fails, closes every file it had already opened and then
     * propagates the failure, so a half-loaded index does not leak file handles.
     */
    private static SafeTensorIndex loadWeightsOrClose(SafeTensorIndex index, Path modelRoot) {
        try {
            loadWeights(index, modelRoot);
            return index;
        } catch (IOException e) {
            index.closeFilesQuietly();
            throw new UncheckedIOException(e);
        } catch (RuntimeException e) {
            index.closeFilesQuietly();
            throw e;
        }
    }

    /**
     * Opens and maps every distinct file named in {@code index.weightFileMap}, registering a
     * {@link Weights} view for each tensor (or tensor chunk) found in it.
     *
     * @param index     index to populate
     * @param modelRoot directory the file names are resolved against
     * @throws IOException if a file cannot be opened or mapped
     */
    static void loadWeights(SafeTensorIndex index, Path modelRoot) throws IOException {
        for (String fileName : index.weightFileMap.values()) {
            // Several entries may name the same file; open and map each file only once.
            if (index.fileMap.containsKey(fileName)) {
                continue;
            }
            RandomAccessFile raf = new RandomAccessFile(Paths.get(modelRoot.toString(), fileName).toFile(), "r");
            // Register before mapping so the handle is closed even if a later step fails.
            index.fileMap.put(fileName, raf);
            FileChannel channel = raf.getChannel();
            long fileLength = raf.length();

            MappedByteBuffer header =
                    channel.map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(HEADER_MAP_LIMIT, fileLength));
            Map<String, String> fileMetadata = new HashMap<>();
            Map<String, TensorInfo> tensorInfos = SafeTensorSupport.readTensorInfoMap(header, Optional.of(fileMetadata));
            index.allTensorInfoMap.putAll(tensorInfos);
            // Split offsets are relative to where the tensor-info table ends.
            int dataStart = header.position();

            for (Map.Entry<List<Long>, List<String>> split :
                    index.computeMmapSplits(tensorInfos, fileLength).entrySet()) {
                long start = split.getKey().get(0);
                long end = split.getKey().get(1);
                List<String> names = split.getValue();
                Map<String, TensorInfo> splitInfos = tensorInfos.entrySet().stream()
                        .filter(e -> names.contains(e.getKey()))
                        .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                MappedByteBuffer buffer =
                        channel.map(FileChannel.MapMode.READ_ONLY, dataStart + start, Ints.checkedCast(end - start));
                Weights weights = new Weights(fileMetadata, splitInfos, buffer, Optional.of(index));
                for (String name : names) {
                    index.weightMap.put(name, weights);
                }
            }
        }
    }

    /**
     * Groups the tensors of one file into byte ranges ("splits") that each fit in a single
     * {@link FileChannel#map} call, i.e. span less than {@link Integer#MAX_VALUE} bytes.
     *
     * <p>Side effects: each visited tensor's {@code dataOffsets} are rebased in place by the end of the
     * previous split. A tensor too large for a split on its own gets chunk entries named
     * {@code "<name>-part-<n>"} added to {@code tensorInfos}; each chunk also becomes its own split.
     *
     * @param tensorInfos tensor name -> info for one file (mutated, see above)
     * @param fileLength  file length, used as the initial value when tracking a split's start offset
     * @return map from {@code [start, end)} offsets to the names of the tensors in that range
     */
    private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo> tensorInfos, long fileLength) {
        Map<List<Long>, List<String>> splits = new HashMap<>();
        long lastSplitOffset = 0;
        int tensorsInFile = tensorInfos.size();
        int tensorsSplit = 0;
        List<String> tensors = new ArrayList<>();
        // Walk a snapshot: chunk entries are added to tensorInfos during the walk.
        Iterator<Map.Entry<String, TensorInfo>> it = new ArrayList<>(tensorInfos.entrySet()).iterator();
        // Entry carried over from the previous split when it did not fit there.
        Map.Entry<String, TensorInfo> next = null;

        while (tensorsSplit < tensorsInFile && (it.hasNext() || next != null)) {
            tensors.clear();
            long limit = lastSplitOffset + Integer.MAX_VALUE;
            long startOffset = fileLength;
            long endOffset = 0;

            while (it.hasNext() || next != null) {
                next = next == null ? it.next() : next;
                TensorInfo info = next.getValue();
                logger.debug("Tensor {} {} {} limit {}", next.getKey(), info.dataOffsets[0], info.dataOffsets[1], limit);

                if (info.dataOffsets[1] < limit) {
                    tensors.add(next.getKey());
                    tensorsSplit++;
                    endOffset = Math.max(endOffset, info.dataOffsets[1]);
                    startOffset = Math.min(startOffset, info.dataOffsets[0]);
                    // Make the offsets relative to the end of the previous split.
                    info.dataOffsets[0] -= lastSplitOffset;
                    info.dataOffsets[1] -= lastSplitOffset;
                    logger.debug("Adding tensor {} to split {}-{}", next.getKey(), info.dataOffsets[0], info.dataOffsets[1]);
                    next = null; // consumed
                    continue;
                }

                if (tensors.isEmpty()) {
                    // This tensor alone exceeds the limit: carve it into 2D row chunks ("-part-<n>").
                    // rowSize = dType.size() * shape[1], presumably bytes per row -- TODO confirm DType.size().
                    int rowSize = info.dType.size() * info.shape[1];
                    endOffset = Math.max(endOffset, info.dataOffsets[1]);
                    startOffset = Math.min(startOffset, info.dataOffsets[0]);
                    info.dataOffsets[0] -= lastSplitOffset;
                    info.dataOffsets[1] -= lastSplitOffset;

                    long offset = info.dataOffsets[0];
                    // Largest multiple of rowSize that still fits in an int-sized mapping.
                    long chunkSize = Integer.MAX_VALUE - (Integer.MAX_VALUE % rowSize);
                    long offsetAdded = 0;
                    int chunk = 0;
                    boolean added = false;
                    for (long remaining = info.dataOffsets[1] - offset; remaining > 0; remaining -= chunkSize) {
                        // NOTE(review): endOffset is an un-rebased offset while offset has been rebased by
                        // lastSplitOffset (and the split key below uses rebased offsets too); these only
                        // agree when lastSplitOffset == 0 -- confirm for oversized tensors later in a file.
                        long chunkEnd = Math.min(offset + chunkSize, endOffset);
                        String chunkName = next.getKey() + "-part-" + chunk++;
                        logger.debug("Adding chunk {} to split {}-{} {}", chunkName, offset, chunkEnd, Ints.checkedCast(chunkEnd - offset));
                        splits.put(List.of(offset, chunkEnd), List.of(chunkName));
                        assert info.shape.length == 2 : "Only 2D tensors supported";
                        // Chunk offsets are relative to the chunk's own mapping.
                        tensorInfos.put(chunkName, new TensorInfo(
                                info.dType,
                                new long[] {Ints.checkedCast((chunkEnd - offset) / rowSize), info.shape[1]},
                                new long[] {offset - offsetAdded, chunkEnd - offsetAdded}));
                        added = true;
                        offsetAdded += chunkEnd - offset;
                        offset = chunkEnd;
                    }
                    if (added) {
                        tensorsSplit++;
                        next = null;
                    }
                }
                // Close the current split. Any entry still held in `next` starts the next split.
                // (Without this break the loop never terminated once a tensor failed to fit
                // into a non-empty split.)
                break;
            }

            assert tensorsSplit > 0 : "No tensors in split";
            logger.debug("Adding split {}-{} with {} tensors of {}", startOffset, endOffset, tensors.size(), tensorsSplit);
            if (!tensors.isEmpty()) {
                splits.put(List.of(startOffset, endOffset), new ArrayList<>(tensors));
            }
            lastSplitOffset = Math.max(lastSplitOffset, endOffset);
        }

        assert tensorsSplit == tensorsInFile : "Not all tensors were split: " + tensorsSplit + " != " + tensorsInFile;
        return splits;
    }

    /**
     * Creates an index from its JSON form; also used directly by {@link #loadSingleFile}.
     *
     * @param metadata      the {@code metadata} object; {@code null} (absent) is treated as empty
     * @param weightFileMap the {@code weight_map} object; values are file names to load
     * @throws NullPointerException if {@code weightFileMap} is {@code null} or contains null keys/values
     */
    @JsonCreator
    public SafeTensorIndex(
            @JsonProperty("metadata") Map<String, String> metadata,
            @JsonProperty("weight_map") Map<String, String> weightFileMap) {
        this.metadata = metadata == null ? ImmutableMap.of() : ImmutableMap.copyOf(metadata);
        this.weightFileMap = ImmutableMap.copyOf(Objects.requireNonNull(weightFileMap, "weight_map"));
    }

    @Override
    public Map<String, String> metadata() {
        return this.metadata;
    }

    /** Returns the live, mutable map of tensor infos gathered from every loaded file. */
    @Override
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.allTensorInfoMap;
    }

    /**
     * Loads a tensor by name. Falls back to reassembling {@code "<name>-part-0"},
     * {@code "<name>-part-1"}, ... chunks when the tensor was too large for a single mapping.
     *
     * @param name  tensor name
     * @param dctx  forwarded unchanged to {@code Weights.load}
     * @param flagA forwarded unchanged to {@code Weights.load}
     * @param flagB forwarded unchanged to {@code Weights.load}
     * @return the tensor, or a {@link SegmentedTensor} wrapping its chunks
     * @throws NoSuchElementException if neither the tensor nor any of its chunks is known
     */
    @Override
    public AbstractTensor load(String name, DistributedContext dctx, boolean flagA, boolean flagB) {
        Weights weights = this.weightMap.get(name);
        if (weights != null) {
            return weights.load(name, dctx, flagA, flagB);
        }
        List<AbstractTensor> parts = new ArrayList<>();
        for (int part = 0; ; part++) {
            String partName = name + "-part-" + part;
            Weights partWeights = this.weightMap.get(partName);
            if (partWeights == null) {
                break;
            }
            parts.add(partWeights.load(partName, dctx, flagA, flagB));
        }
        if (parts.isEmpty()) {
            throw new NoSuchElementException(name);
        }
        return SegmentedTensor.wrap(parts);
    }

    /**
     * Returns the model dtype as reported by an arbitrary loaded {@link Weights} instance.
     *
     * @throws NoSuchElementException if no weights have been loaded
     */
    @Override
    public DType getModelDType() {
        Iterator<Weights> it = this.weightMap.values().iterator();
        if (!it.hasNext()) {
            throw new NoSuchElementException("No weights loaded; cannot determine model dtype");
        }
        return it.next().getModelDType();
    }

    /** Drops all registered weights, closes every opened file and clears the tensor-info map. */
    @Override
    public void close() throws Exception {
        this.weightMap.clear();
        closeFilesQuietly();
        this.allTensorInfoMap.clear();
    }

    /** Closes every opened file, logging (not throwing) individual failures, then forgets them. */
    private void closeFilesQuietly() {
        this.fileMap.forEach((fileName, raf) -> {
            try {
                raf.close();
            } catch (IOException e) {
                logger.warn("Failed to close weight file {}", fileName, e);
            }
        });
        this.fileMap.clear();
    }
}
