package org.nd4j.linalg.cpu.nativecpu.cache;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.cache.ArrayDescriptor;
import org.nd4j.linalg.cache.BasicConstantHandler;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.class */
/**
 * Cache of detached constant {@link DataBuffer}s keyed by the contents of the
 * primitive array they were created from (via {@link ArrayDescriptor}).
 *
 * <p>The cache is bounded by a soft limit of {@link #MAX_ENTRIES} insertions.
 * The limit is checked before inserting without any locking, so concurrent
 * callers may overshoot it slightly; {@code int} arrays with fewer than four
 * elements are always cached regardless of the limit.
 *
 * <p>Besides the buffers themselves, the class keeps a running count of the
 * cached entries and an estimate of the bytes they occupy, exposed through
 * {@link #getCachedBytes()}.
 */
public class ConstantBuffersCache extends BasicConstantHandler {
    /** Cached buffers; replaced wholesale by {@link #purgeConstants()}. */
    protected Map<ArrayDescriptor, DataBuffer> buffersCache = new ConcurrentHashMap<>();

    /** Number of entries inserted since construction or the last purge. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /** Estimated size, in bytes, of the entries inserted since construction or the last purge. */
    private final AtomicLong bytes = new AtomicLong(0);

    /** Soft upper bound on the number of cached entries. */
    private static final int MAX_ENTRIES = 1000;

    /**
     * Returns a constant buffer holding the given {@code int} values, reusing a
     * cached buffer when one exists for an equal descriptor.
     *
     * <p>Arrays with fewer than four elements bypass the {@link #MAX_ENTRIES}
     * limit and are always cached (presumably because they are short
     * shape/dimension arrays that recur constantly; confirm with the callers).
     *
     * @param iArr the values the buffer must contain
     * @return the cached buffer for these values, or a newly created detached buffer
     */
    public DataBuffer getConstantBuffer(int[] iArr) {
        // Snapshot the field so a concurrent purgeConstants() cannot swap the map mid-call.
        final Map<ArrayDescriptor, DataBuffer> cache = this.buffersCache;
        final ArrayDescriptor descriptor = new ArrayDescriptor(iArr);

        final DataBuffer cached = cache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        final DataBuffer created = Nd4j.createBufferDetached(iArr);
        if (this.counter.get() < MAX_ENTRIES || iArr.length < 4) {
            final DataBuffer winner = cache.putIfAbsent(descriptor, created);
            if (winner != null) {
                // Another thread cached an equal constant first; hand out that instance
                // and skip the accounting so the entry is not counted twice.
                return winner;
            }
            this.counter.incrementAndGet();
            // Elements are accounted at 4 bytes each; widen first to avoid int overflow.
            this.bytes.addAndGet((long) iArr.length * 4);
        }
        return created;
    }

    /**
     * Drops every cached constant and resets the entry counter and the byte
     * estimate, so that caching resumes with the full {@link #MAX_ENTRIES}
     * budget and {@link #getCachedBytes()} reflects the now-empty cache.
     *
     * <p>Buffers already handed out to callers are not affected. Accounting is
     * best-effort if other threads insert while the purge is running.
     */
    public void purgeConstants() {
        this.buffersCache = new ConcurrentHashMap<>();
        this.counter.set(0);
        this.bytes.set(0);
    }

    /**
     * Returns a constant buffer holding the given {@code float} values, reusing
     * a cached buffer when one exists for an equal descriptor.
     *
     * @param fArr the values the buffer must contain
     * @return the cached buffer for these values, or a newly created detached buffer
     */
    public DataBuffer getConstantBuffer(float[] fArr) {
        // Snapshot the field so a concurrent purgeConstants() cannot swap the map mid-call.
        final Map<ArrayDescriptor, DataBuffer> cache = this.buffersCache;
        final ArrayDescriptor descriptor = new ArrayDescriptor(fArr);

        final DataBuffer cached = cache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        final DataBuffer created = Nd4j.createBufferDetached(fArr);
        if (this.counter.get() < MAX_ENTRIES) {
            final DataBuffer winner = cache.putIfAbsent(descriptor, created);
            if (winner != null) {
                // Lost the insertion race: reuse the stored instance, do not count it again.
                return winner;
            }
            this.counter.incrementAndGet();
            // Element size comes from Nd4j.sizeOfDataType(), not from the Java float width.
            this.bytes.addAndGet((long) fArr.length * Nd4j.sizeOfDataType());
        }
        return created;
    }

    /**
     * Returns a constant buffer holding the given {@code double} values, reusing
     * a cached buffer when one exists for an equal descriptor.
     *
     * @param dArr the values the buffer must contain
     * @return the cached buffer for these values, or a newly created detached buffer
     */
    public DataBuffer getConstantBuffer(double[] dArr) {
        // Snapshot the field so a concurrent purgeConstants() cannot swap the map mid-call.
        final Map<ArrayDescriptor, DataBuffer> cache = this.buffersCache;
        final ArrayDescriptor descriptor = new ArrayDescriptor(dArr);

        final DataBuffer cached = cache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        final DataBuffer created = Nd4j.createBufferDetached(dArr);
        if (this.counter.get() < MAX_ENTRIES) {
            final DataBuffer winner = cache.putIfAbsent(descriptor, created);
            if (winner != null) {
                // Lost the insertion race: reuse the stored instance, do not count it again.
                return winner;
            }
            this.counter.incrementAndGet();
            // NOTE(review): sized with Nd4j.sizeOfDataType() rather than 8 bytes per double;
            // presumably the buffer is stored in the globally configured data type - confirm.
            this.bytes.addAndGet((long) dArr.length * Nd4j.sizeOfDataType());
        }
        return created;
    }

    /**
     * Returns the estimated number of bytes held by the cached buffers.
     *
     * @return the running byte total accumulated as buffers were cached
     */
    public long getCachedBytes() {
        return this.bytes.get();
    }
}
