package io.trino.filesystem.gcs;

import com.google.cloud.WriteChannel;
import com.google.cloud.storage.Blob;
import com.google.cloud.storage.Storage;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.Objects;

/**
 * {@link OutputStream} that uploads to a Google Cloud Storage blob through a {@link WriteChannel}.
 *
 * <p>Small writes are accumulated in a fixed {@value #BUFFER_SIZE}-byte heap buffer and pushed
 * to the channel when the buffer fills up, on {@link #flush()}, or on {@link #close()}. Writes
 * larger than the buffer bypass it and go straight to the channel.
 *
 * <p>Memory usage is reported to a {@link LocalMemoryContext} as the buffer size plus the number
 * of bytes written so far, capped at {@code writeBlockSizeBytes}.
 */
public class GcsOutputStream extends OutputStream {
    private static final int BUFFER_SIZE = 8192;

    private final GcsLocation location;
    private final long writeBlockSizeBytes;
    private final LocalMemoryContext memoryContext;
    private final WriteChannel writeChannel;
    private final ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
    private long writtenBytes;
    private boolean closed;

    /**
     * Opens a write channel on {@code blob}.
     *
     * @param location location of the file, used in error messages; must not be null
     * @param blob blob to write to
     * @param memoryContext parent context in which a local memory context for this stream is created
     * @param writeBlockSizeBytes chunk size configured on the write channel (saturated to
     *        {@code int}) and upper bound of the tracked written-byte count; must be non-negative
     * @throws NullPointerException if {@code location} is null
     * @throws IllegalArgumentException if {@code writeBlockSizeBytes} is negative
     */
    public GcsOutputStream(GcsLocation location, Blob blob, AggregatedMemoryContext memoryContext, long writeBlockSizeBytes) {
        this.location = Objects.requireNonNull(location, "location is null");
        Preconditions.checkArgument(writeBlockSizeBytes >= 0, "writeBlockSizeBytes is negative");
        this.writeBlockSizeBytes = writeBlockSizeBytes;
        this.memoryContext = memoryContext.newLocalMemoryContext(GcsOutputStream.class.getSimpleName());
        this.writeChannel = blob.writer(new Storage.BlobWriteOption[0]);
        this.writeChannel.setChunkSize(Ints.saturatedCast(writeBlockSizeBytes));
    }

    /**
     * Buffers a single byte, flushing the buffer first if it is full.
     *
     * @throws IOException if the stream is closed or flushing fails
     */
    @Override
    public void write(int value) throws IOException {
        ensureOpen();
        if (!buffer.hasRemaining()) {
            flush();
        }
        buffer.put((byte) value);
        recordBytesWritten(1);
    }

    /**
     * Writes {@code length} bytes from {@code source} starting at {@code offset}.
     *
     * <p>Chunks larger than the internal buffer are written directly to the channel (after
     * flushing any pending buffered data); smaller chunks are buffered.
     *
     * @throws IOException if the stream is closed or writing to the channel fails
     * @throws IndexOutOfBoundsException if {@code offset}/{@code length} are out of range for {@code source}
     */
    @Override
    public void write(byte[] source, int offset, int length) throws IOException {
        ensureOpen();
        // Validate before any side effect (a flush) so bad arguments fail fast.
        Objects.checkFromIndexSize(offset, length, source.length);
        if (length > BUFFER_SIZE) {
            writeDirect(ByteBuffer.wrap(source, offset, length));
            return;
        }
        if (length > buffer.remaining()) {
            flush();
        }
        buffer.put(source, offset, length);
        recordBytesWritten(length);
    }

    /**
     * Writes all remaining bytes of {@code source} straight to the channel, bypassing the buffer.
     */
    private void writeDirect(ByteBuffer source) throws IOException {
        // Any previously buffered bytes must reach the channel first to preserve ordering.
        flush();
        // Capture the expected size BEFORE writing: write() advances the buffer's position,
        // so remaining() afterwards no longer reflects how many bytes were requested.
        int expectedBytes = source.remaining();
        int bytesWritten;
        try {
            bytesWritten = writeChannel.write(source);
            if (bytesWritten != expectedBytes) {
                throw new IOException("Unexpected bytes written length: %s should be %s".formatted(bytesWritten, expectedBytes));
            }
        }
        catch (IOException e) {
            throw new IOException("Error writing file: " + location, e);
        }
        recordBytesWritten(bytesWritten);
    }

    private void ensureOpen() throws IOException {
        if (closed) {
            throw new IOException("Output stream closed: " + location);
        }
    }

    /**
     * Pushes any buffered bytes to the write channel and empties the buffer.
     *
     * @throws IOException if the stream is closed or the channel write fails
     */
    @Override
    public void flush() throws IOException {
        ensureOpen();
        if (buffer.position() > 0) {
            buffer.flip();
            // Keep writing until the channel has consumed the whole buffer.
            while (buffer.hasRemaining()) {
                try {
                    writeChannel.write(buffer);
                }
                catch (IOException e) {
                    throw new IOException("Error writing file: " + location, e);
                }
            }
            buffer.clear();
        }
    }

    /**
     * Flushes pending data, then closes the write channel and releases the memory context.
     *
     * <p>Calling this on an already-closed stream does nothing. If the final flush fails, the
     * exception propagates, the stream stays open, and the channel is not closed. Once the flush
     * succeeds, the memory context is released even if closing the channel fails.
     *
     * @throws IOException if flushing or closing the channel fails
     */
    @Override
    public void close() throws IOException {
        if (closed) {
            return;
        }
        flush();
        closed = true;
        try {
            writeChannel.close();
        }
        catch (IOException e) {
            throw new IOException("Error closing file: " + location, e);
        }
        finally {
            memoryContext.close();
        }
    }

    /**
     * Adds {@code count} to the written-byte total and, while the total is below
     * {@code writeBlockSizeBytes}, reports buffer size plus the capped total as memory usage.
     */
    private void recordBytesWritten(int count) {
        if (writtenBytes < writeBlockSizeBytes) {
            memoryContext.setBytes(BUFFER_SIZE + Math.min(writtenBytes + count, writeBlockSizeBytes));
        }
        writtenBytes += count;
    }
}
