package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import java.io.Closeable;
import java.io.IOError;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.Iterator;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/tensor/KvBufferCache.class */
public class KvBufferCache implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(KvBufferCache.class);

    /** Session-keyed buffers created on demand by {@link #getKvBuffer(UUID)} and released by {@link #close()}. */
    private final ConcurrentMap<UUID, KvBuffer> kvBufferCache = new ConcurrentHashMap<>();

    /** Model whose configuration and working dtype determine how pages are sized and allocated. */
    private final AbstractModel model;

    /* loaded from: input_file:com/github/tjake/jlama/tensor/KvBufferCache$KvBuffer.class */
    /**
     * Key/value buffer for a single session, stored as a grid of lazily created pages indexed by
     * {@code [layerPage][contextPage]}.
     *
     * <p>Each page holds {@code layersPerPage} layers by {@code contextLengthPerPage} context
     * positions; within a page, index 0 of the second dimension addresses keys and index 1
     * addresses values.
     */
    public class KvBuffer implements AutoCloseable {
        private final UUID session;
        private final AtomicInteger currentContextPosition = new AtomicInteger(0);
        private final KvBufferPage[][] pages;
        private final KvPageContext pageContext;
        private final boolean ephemeral;

        /**
         * @param session            session id; also used when naming file-backed pages
         * @param maxPageSizeInBytes upper bound on the byte size of a single page
         * @param ephemeral          forwarded to every {@link KvBufferPage}; when {@code true} pages
         *                           are never file-backed
         */
        KvBuffer(UUID session, int maxPageSizeInBytes, boolean ephemeral) {
            // session must be assigned first: computePageSize reads it to build the page context.
            this.session = session;
            this.pageContext = computePageSize(maxPageSizeInBytes);
            this.pages = new KvBufferPage[this.pageContext.numberOfLayerPages][this.pageContext.numberOfContextPages];
            this.ephemeral = ephemeral;
        }

        /** Returns the current context position counter. */
        public int getCurrentContextPosition() {
            return this.currentContextPosition.get();
        }

        /** Overwrites the current context position counter. */
        public void setCurrentContextPosition(int position) {
            this.currentContextPosition.set(position);
        }

        /** Atomically advances the context position counter by one. */
        public void incrementContextPosition() {
            this.currentContextPosition.incrementAndGet();
        }

        /**
         * Chooses how many layers and context positions go into one page so that a page does not
         * exceed {@code maxPageSizeInBytes}, and derives how many pages are needed in each dimension.
         *
         * @param maxPageSizeInBytes upper bound for the byte size of one page
         * @return the resulting page layout
         * @throws IllegalArgumentException if the limit is not larger than the storage needed for a
         *         single layer at a single position, or if the computed page exceeds the limit
         */
        public KvPageContext computePageSize(long maxPageSizeInBytes) {
            Config config = KvBufferCache.this.model.getConfig();
            // Key + value (hence the factor 2) for one layer at one context position.
            long bytesPerLayerPosition = 2L * KvBufferCache.this.model.getWorkingDType().size() * config.dctx().kvSegmentLength;
            Preconditions.checkArgument(
                maxPageSizeInBytes > bytesPerLayerPosition,
                "maxPageSizeInBytes must be greater than the size of a single layer");

            int totalLayers = config.dctx().numberOfLayers;
            int totalContext = config.contextLength;

            int layersPerPage = 1;
            int contextPerPage = 1;
            long bestProduct = 0;
            // Try the largest layer count first; for each, take as many positions as fit in the limit.
            for (int layers = totalLayers; layers >= 1; layers--) {
                long positions = maxPageSizeInBytes / (layers * bytesPerLayerPosition);
                if (positions < 1 || positions > totalContext) {
                    continue;
                }
                long product = layers * positions;
                if (product > bestProduct) {
                    layersPerPage = layers;
                    contextPerPage = (int) positions;
                    bestProduct = product;
                } else if (product < bestProduct) {
                    // Stop at the first candidate that covers less than the best one seen so far.
                    break;
                }
            }

            // Round up: a trailing partial page is still a page. (Integer division followed by
            // Math.ceil would truncate and under-allocate the page grid.)
            int numberOfLayerPages = (totalLayers + layersPerPage - 1) / layersPerPage;
            int numberOfContextPages = (totalContext + contextPerPage - 1) / contextPerPage;

            long pageSize = (long) layersPerPage * contextPerPage * bytesPerLayerPosition;
            if (pageSize > maxPageSizeInBytes) {
                throw new IllegalArgumentException(
                    "Calculation error: pageSize > maxPageSizeInBytes: " + pageSize + " > " + maxPageSizeInBytes);
            }
            KvBufferCache.logger.debug(
                "Optimal page size: {} layers, {} context length, {} bytes, {} layer pages, {} length pages",
                layersPerPage, contextPerPage, pageSize, numberOfLayerPages, numberOfContextPages);
            return new KvPageContext(
                KvBufferCache.this, this.session, numberOfLayerPages, numberOfContextPages, layersPerPage, contextPerPage);
        }

        /**
         * Closes every page that has been created. A failure to close one page is logged at debug
         * level and does not prevent the remaining pages from being closed.
         */
        @Override // java.lang.AutoCloseable
        public void close() {
            for (KvBufferPage[] layerPages : this.pages) {
                if (layerPages == null) {
                    continue;
                }
                for (KvBufferPage page : layerPages) {
                    if (page == null) {
                        continue;
                    }
                    try {
                        page.close();
                    } catch (IOException e) {
                        KvBufferCache.logger.debug("Error closing page", e);
                    }
                }
            }
        }

        /** Returns the key slice for {@code layerIndex} at context {@code position}. */
        public AbstractTensor getKeyTensorForPosition(int layerIndex, int position) {
            return getTensorForPosition(layerIndex, position, 0);
        }

        /** Returns the value slice for {@code layerIndex} at context {@code position}. */
        public AbstractTensor getValTensorForPosition(int layerIndex, int position) {
            return getTensorForPosition(layerIndex, position, 1);
        }

        /**
         * Locates (creating if needed) the page containing the given layer and position and slices
         * it down to that layer, key/value index and in-page position.
         *
         * @param kvIndex 0 for keys, 1 for values
         */
        private AbstractTensor getTensorForPosition(int layerIndex, int position, int kvIndex) {
            int layerPageIndex = layerIndex / this.pageContext.layersPerPage;
            int contextPageIndex = position / this.pageContext.contextLengthPerPage;
            int layerInPage = layerIndex % this.pageContext.layersPerPage;
            int positionInPage = position % this.pageContext.contextLengthPerPage;
            return pageAt(layerPageIndex, contextPageIndex).getTensor().slice(true, layerInPage, kvIndex, positionInPage);
        }

        /**
         * Returns one key slice per context page, from the first page up to and including the page
         * that contains {@code upperBound}, for the given layer.
         */
        public AbstractTensor[] getKeyTensorsUptoPosition(int layerIndex, int upperBound) {
            return getTensorsUptoPosition(layerIndex, 0, upperBound);
        }

        /**
         * Returns one value slice per context page, from the first page up to and including the page
         * that contains {@code upperBound}, for the given layer.
         */
        public AbstractTensor[] getValTensorsUptoPosition(int layerIndex, int upperBound) {
            return getTensorsUptoPosition(layerIndex, 1, upperBound);
        }

        /**
         * Collects the per-page slices for one layer and one key/value index across context pages
         * {@code 0 .. upperBound / contextLengthPerPage}.
         *
         * @param kvIndex 0 for keys, 1 for values
         */
        private AbstractTensor[] getTensorsUptoPosition(int layerIndex, int kvIndex, int upperBound) {
            int layerPageIndex = layerIndex / this.pageContext.layersPerPage;
            int lastContextPageIndex = upperBound / this.pageContext.contextLengthPerPage;
            int layerInPage = layerIndex % this.pageContext.layersPerPage;

            AbstractTensor[] slices = new AbstractTensor[lastContextPageIndex + 1];
            for (int contextPageIndex = 0; contextPageIndex <= lastContextPageIndex; contextPageIndex++) {
                // Each page is looked up (and named) by its own index, not by the upper bound's page.
                slices[contextPageIndex] = pageAt(layerPageIndex, contextPageIndex).getTensor().slice(true, layerInPage, kvIndex);
            }
            return slices;
        }

        /**
         * Returns the open page at the given grid coordinates, creating a new one when the slot is
         * empty or holds a closed page. The page id encodes the coordinates as {@code L<layer>C<context>}.
         */
        private KvBufferPage pageAt(int layerPageIndex, int contextPageIndex) {
            KvBufferPage page = this.pages[layerPageIndex][contextPageIndex];
            if (page == null || page.isClosed()) {
                page = new KvBufferPage(
                    KvBufferCache.this, this.pageContext, "L" + layerPageIndex + "C" + contextPageIndex, this.ephemeral);
                this.pages[layerPageIndex][contextPageIndex] = page;
            }
            return page;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/github/tjake/jlama/tensor/KvBufferCache$KvBufferPage.class */
    public class KvBufferPage implements AutoCloseable {
        private final AbstractTensor tensor;
        private final KvPageContext pageCtx;
        private final String pageId;
        private final AtomicBoolean closed = new AtomicBoolean(false);
        private final RandomAccessFile raf;
        static final /* synthetic */ boolean $assertionsDisabled;

        KvBufferPage(KvBufferCache kvBufferCache, KvPageContext kvPageContext, String str, boolean z) {
            AbstractTensor bFloat16BufferTensor;
            this.pageCtx = kvPageContext;
            this.pageId = str;
            if (kvBufferCache.model.getConfig().workingDirectory().isEmpty() || z) {
                this.raf = null;
                this.tensor = TensorCache.instance.get(kvBufferCache.model.getWorkingDType(), kvPageContext.pageShape);
                return;
            }
            try {
                this.raf = new RandomAccessFile(Paths.get(kvBufferCache.model.getConfig().workingDirectory().get().toString(), kvPageContext.session.toString() + "-" + str + ".page").toFile(), "rw");
                long size = kvPageContext.pageShape.size() * kvBufferCache.model.getWorkingDType().size();
                KvBufferCache.logger.debug("Allocating page {} with {} bytes {}", new Object[]{str, Long.valueOf(size), Long.valueOf(this.raf.length())});
                if (this.raf.length() != size) {
                    this.raf.setLength(size);
                }
                if (kvBufferCache.model.getWorkingDType() == DType.F32) {
                    bFloat16BufferTensor = new FloatBufferTensor(this.raf.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, size).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(), kvPageContext.pageShape, true);
                } else {
                    if (kvBufferCache.model.getWorkingDType() != DType.BF16) {
                        throw new UnsupportedOperationException("Only F32/BF16 is supported for now");
                    }
                    bFloat16BufferTensor = new BFloat16BufferTensor("kvmem", this.raf.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, size).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(), kvPageContext.pageShape, true);
                }
                this.tensor = bFloat16BufferTensor;
            } catch (IOException e) {
                throw new IOError(e);
            }
        }

        public AbstractTensor getTensor() {
            if ($assertionsDisabled || !this.closed.get()) {
                return this.tensor;
            }
            throw new AssertionError("Page is closed");
        }

        public boolean isClosed() {
            return this.closed.get();
        }

        @Override // java.lang.AutoCloseable
        public void close() throws IOException {
            if (this.closed.compareAndSet(false, true)) {
                if (this.raf != null) {
                    this.raf.close();
                }
                this.tensor.close();
            }
        }

        static {
            $assertionsDisabled = !KvBufferCache.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/github/tjake/jlama/tensor/KvBufferCache$KvPageContext.class */
    /**
     * Immutable description of how a session's key/value storage is split into pages: how many
     * pages exist along the layer and context dimensions, how much each page holds, and the
     * resulting tensor shape of one page.
     */
    public class KvPageContext {
        public final int numberOfLayerPages;
        public final int numberOfContextPages;
        private final int layersPerPage;
        private final int contextLengthPerPage;
        private final UUID session;
        /** Shape of one page: {@code [layersPerPage, 2, contextLengthPerPage, kvLength]}. */
        public final TensorShape pageShape;

        /**
         * @param kvBufferCache        owning cache, used to read the model configuration
         * @param session              session the pages belong to
         * @param numberOfLayerPages   page count along the layer dimension
         * @param numberOfContextPages page count along the context dimension
         * @param layersPerPage        layers stored in one page
         * @param contextLengthPerPage context positions stored in one page
         * @throws IllegalArgumentException if any of the four counts is less than 1
         */
        public KvPageContext(
            KvBufferCache kvBufferCache,
            UUID session,
            int numberOfLayerPages,
            int numberOfContextPages,
            int layersPerPage,
            int contextLengthPerPage
        ) {
            if (numberOfLayerPages < 1) {
                throw new IllegalArgumentException("numberOfLayerPages must be >= 1");
            }
            if (numberOfContextPages < 1) {
                throw new IllegalArgumentException("numberOfContextPages must be >= 1");
            }
            if (layersPerPage < 1) {
                throw new IllegalArgumentException("layersPerPage must be >= 1");
            }
            if (contextLengthPerPage < 1) {
                throw new IllegalArgumentException("contextLengthPerPage must be >= 1");
            }
            this.session = session;
            this.numberOfLayerPages = numberOfLayerPages;
            this.numberOfContextPages = numberOfContextPages;
            this.layersPerPage = layersPerPage;
            this.contextLengthPerPage = contextLengthPerPage;

            Config config = kvBufferCache.model.getConfig();
            DistributedContext dctx = config.dctx();
            // Second dimension of size 2: index 0 is used for keys, index 1 for values.
            int[] dims = {layersPerPage, 2, contextLengthPerPage, config.kvLength};
            if (config.kvLength != dctx.kvSegmentLength) {
                // Only a segment of the kv width is handled here, so use a sparse-column shape
                // restricted to [kvSegmentStart, kvSegmentEnd].
                this.pageShape = TensorShape.sparseColumn(dims, Pair.of(dctx.kvSegmentStart, dctx.kvSegmentEnd));
            } else {
                this.pageShape = TensorShape.of(dims);
            }
        }
    }

    /**
     * Creates an empty cache bound to the given model.
     *
     * @param abstractModel model whose configuration and working dtype are consulted whenever a
     *                      buffer or page is created
     */
    public KvBufferCache(AbstractModel abstractModel) {
        model = abstractModel;
    }

    /**
     * Returns the buffer registered for the given session, creating and registering a
     * non-ephemeral one on first use.
     *
     * @param uuid session id
     * @return the session's buffer
     */
    public KvBuffer getKvBuffer(UUID uuid) {
        // Page-size limit handed to new session buffers: 8 MiB.
        final int sessionPageLimitBytes = 8 * 1024 * 1024;
        return this.kvBufferCache.computeIfAbsent(uuid, id -> new KvBuffer(id, sessionPageLimitBytes, false));
    }

    /**
     * Creates a fresh ephemeral buffer under a random session id. The buffer is not registered in
     * this cache, so {@link #close()} on the cache does not close it.
     *
     * @return a new ephemeral buffer
     */
    public KvBuffer getEphemeralKvBuffer() {
        // Page-size limit handed to ephemeral buffers: 1 MiB.
        final int ephemeralPageLimitBytes = 1024 * 1024;
        UUID throwawaySession = UUID.randomUUID();
        return new KvBuffer(throwawaySession, ephemeralPageLimitBytes, true);
    }

    /** Closes every registered session buffer and removes it from the cache. */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        for (Iterator<Map.Entry<UUID, KvBuffer>> entries = this.kvBufferCache.entrySet().iterator(); entries.hasNext(); ) {
            KvBuffer buffer = entries.next().getValue();
            buffer.close();
            // Drop the entry only after its buffer has been closed.
            entries.remove();
        }
    }
}
