package io.trino.orc.stream;

import com.google.common.base.MoreObjects;
import io.airlift.slice.FixedLengthSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.orc.OrcCorruptionException;
import io.trino.orc.OrcDataSourceId;
import io.trino.orc.OrcDecompressor;
import io.trino.orc.checkpoint.InputStreamCheckpoint;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/* loaded from: input_file:io/trino/orc/stream/CompressedOrcChunkLoader.class */
/**
 * {@link OrcChunkLoader} for compressed ORC streams.
 *
 * <p>Every chunk in the stream is preceded by a 3-byte header. The lowest bit of the first byte
 * flags a chunk stored uncompressed; the remaining 23 bits (little-endian) hold the chunk length
 * in bytes. Compressed chunks are inflated with the supplied {@link OrcDecompressor} into a
 * single output buffer that is reused across calls, so a slice returned by {@link #nextChunk()}
 * for a compressed chunk is only valid until the next call to {@code nextChunk()}.
 *
 * <p>The loader keeps a mutable read position without any synchronization.
 */
public final class CompressedOrcChunkLoader implements OrcChunkLoader {
    /** Size in bytes of the header that precedes every chunk. */
    private static final int CHUNK_HEADER_SIZE = 3;

    private final OrcDataReader dataReader;
    private final LocalMemoryContext dataReaderMemoryUsage;
    private final OrcDecompressor decompressor;
    private final LocalMemoryContext decompressionBufferMemoryUsage;

    // Window over the compressed bytes most recently fetched from dataReader.
    private FixedLengthSliceInput compressedBufferStream = Slices.EMPTY_SLICE.getInput();
    // Stream-relative offset of the first byte of compressedBufferStream.
    private int compressedBufferStart;
    // Number of decoded bytes to skip at the start of the next chunk (set by seekToCheckpoint).
    private int nextUncompressedOffset;
    private long lastCheckpoint;

    // Reusable decompression target; (re)allocated by the OutputBuffer callbacks.
    private byte[] decompressorOutputBuffer;

    /**
     * @param dataReader source of the compressed stream bytes
     * @param decompressor codec used to inflate compressed chunks
     * @param memoryContext parent context used to report reader and decompression-buffer memory
     * @throws NullPointerException if any argument is null
     */
    public CompressedOrcChunkLoader(OrcDataReader dataReader, OrcDecompressor decompressor, AggregatedMemoryContext memoryContext) {
        this.dataReader = Objects.requireNonNull(dataReader, "dataReader is null");
        this.decompressor = Objects.requireNonNull(decompressor, "decompressor is null");
        Objects.requireNonNull(memoryContext, "memoryContext is null");
        this.dataReaderMemoryUsage = memoryContext.newLocalMemoryContext(CompressedOrcChunkLoader.class.getSimpleName());
        this.dataReaderMemoryUsage.setBytes(dataReader.getRetainedSize());
        this.decompressionBufferMemoryUsage = memoryContext.newLocalMemoryContext(CompressedOrcChunkLoader.class.getSimpleName());
    }

    @Override
    public OrcDataSourceId getOrcDataSourceId() {
        return dataReader.getOrcDataSourceId();
    }

    /** Stream-relative offset of the next compressed byte to be read. */
    private int getCurrentCompressedOffset() {
        return Math.toIntExact(compressedBufferStart + compressedBufferStream.position());
    }

    @Override
    public boolean hasNextChunk() {
        return getCurrentCompressedOffset() < dataReader.getSize();
    }

    @Override
    public long getLastCheckpoint() {
        return lastCheckpoint;
    }

    /**
     * Positions the loader at the compressed block and decoded offset encoded in {@code checkpoint}.
     *
     * @throws OrcCorruptionException if the compressed offset is at or past the end of the stream
     */
    @Override
    public void seekToCheckpoint(long checkpoint) throws IOException {
        int compressedOffset = InputStreamCheckpoint.decodeCompressedBlockOffset(checkpoint);
        if (compressedOffset >= dataReader.getSize()) {
            throw new OrcCorruptionException(dataReader.getOrcDataSourceId(), "Seek past end of stream");
        }
        // Reuse the current buffer when the target lies inside it; otherwise drop it so the next
        // read fetches fresh data starting at the target offset.
        if (compressedBufferStart <= compressedOffset && compressedOffset < compressedBufferStart + compressedBufferStream.length()) {
            compressedBufferStream.setPosition(compressedOffset - compressedBufferStart);
        } else {
            compressedBufferStart = compressedOffset;
            compressedBufferStream = Slices.EMPTY_SLICE.getInput();
        }
        nextUncompressedOffset = InputStreamCheckpoint.decodeDecompressedOffset(checkpoint);
        lastCheckpoint = checkpoint;
    }

    /**
     * Reads the next chunk and returns its decoded bytes, skipping any decoded offset requested by
     * the last {@link #seekToCheckpoint(long)}.
     *
     * <p>For compressed chunks the returned slice wraps the reusable decompression buffer and is
     * overwritten by the next call.
     *
     * @throws OrcCorruptionException if the stream is truncated, a chunk exceeds the maximum buffer
     *     size, or the checkpoint's decoded offset lies past the end of the chunk
     */
    @Override
    public Slice nextChunk() throws IOException {
        ensureCompressedBytesAvailable(CHUNK_HEADER_SIZE);
        lastCheckpoint = InputStreamCheckpoint.createInputStreamCheckpoint(getCurrentCompressedOffset(), nextUncompressedOffset);

        int b0 = compressedBufferStream.readUnsignedByte();
        int b1 = compressedBufferStream.readUnsignedByte();
        int b2 = compressedBufferStream.readUnsignedByte();
        boolean isUncompressed = (b0 & 0x01) == 1;
        int chunkLength = (b2 << 15) | (b1 << 7) | (b0 >>> 1);

        ensureCompressedBytesAvailable(chunkLength);
        Slice chunk = compressedBufferStream.readSlice(chunkLength);

        if (!isUncompressed) {
            // Decompress first and only then read decompressorOutputBuffer: the OutputBuffer
            // callbacks may allocate or replace the array during decompression, and Java evaluates
            // call arguments left to right, so reading the field in the same call would capture
            // a null or stale array.
            int uncompressedSize = decompressor.decompress(chunk.byteArray(), chunk.byteArrayOffset(), chunk.length(), createOutputBuffer());
            chunk = Slices.wrappedBuffer(decompressorOutputBuffer, 0, uncompressedSize);
        }

        if (nextUncompressedOffset != 0) {
            int skip = nextUncompressedOffset;
            if (skip > chunk.length()) {
                throw new OrcCorruptionException(dataReader.getOrcDataSourceId(), "Checkpoint decompressed offset %s is past end of chunk (%s bytes)", skip, chunk.length());
            }
            chunk = chunk.slice(skip, chunk.length() - skip);
            nextUncompressedOffset = 0;
            // The checkpoint pointed at the very end of this chunk: continue with the next one.
            if (chunk.length() == 0) {
                chunk = nextChunk();
            }
        }
        return chunk;
    }

    /**
     * Ensures at least {@code size} bytes are readable from {@code compressedBufferStream}. If they
     * are not, it fetches a new buffer from the data reader starting at the current offset.
     */
    private void ensureCompressedBytesAvailable(int size) throws IOException {
        if (size <= compressedBufferStream.remaining()) {
            return;
        }
        if (size > dataReader.getMaxBufferSize()) {
            throw new OrcCorruptionException(dataReader.getOrcDataSourceId(), "Requested read size (%s bytes) is greater than max buffer size (%s bytes)", size, dataReader.getMaxBufferSize());
        }
        if (compressedBufferStart + compressedBufferStream.position() + size > dataReader.getSize()) {
            throw new OrcCorruptionException(dataReader.getOrcDataSourceId(), "Read past end of stream");
        }

        compressedBufferStart += Math.toIntExact(compressedBufferStream.position());
        Slice compressedBuffer = dataReader.seekBuffer(compressedBufferStart);
        // seekBuffer may change the reader's retained size; report the new value.
        dataReaderMemoryUsage.setBytes(dataReader.getRetainedSize());
        if (compressedBuffer.length() < size) {
            throw new OrcCorruptionException(dataReader.getOrcDataSourceId(), "Requested read of %s bytes but only %s bytes were available", size, compressedBuffer.length());
        }
        compressedBufferStream = compressedBuffer.getInput();
    }

    /**
     * Creates the decompressor callback that manages {@code decompressorOutputBuffer}.
     * The array is only ever replaced by a larger one, and every replacement is reported to
     * {@code decompressionBufferMemoryUsage}.
     */
    private OrcDecompressor.OutputBuffer createOutputBuffer() {
        return new OrcDecompressor.OutputBuffer() {
            @Override
            public byte[] initialize(int size) {
                if (decompressorOutputBuffer == null || size > decompressorOutputBuffer.length) {
                    // Contents need not be preserved: decompression starts from scratch.
                    decompressorOutputBuffer = new byte[size];
                    decompressionBufferMemoryUsage.setBytes(decompressorOutputBuffer.length);
                }
                return decompressorOutputBuffer;
            }

            @Override
            public byte[] grow(int size) {
                if (size > decompressorOutputBuffer.length) {
                    // Preserve the bytes already decompressed into the buffer.
                    decompressorOutputBuffer = Arrays.copyOf(decompressorOutputBuffer, size);
                    decompressionBufferMemoryUsage.setBytes(decompressorOutputBuffer.length);
                }
                return decompressorOutputBuffer;
            }
        };
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this)
                .add("loader", dataReader)
                .add("compressedOffset", getCurrentCompressedOffset())
                .add("decompressor", decompressor)
                .toString();
    }
}
