package org.apache.seatunnel.engine.server.task;

import com.hazelcast.cluster.Address;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.sink.SinkAggregatedCommitter;
import org.apache.seatunnel.engine.common.utils.ExceptionUtil;
import org.apache.seatunnel.engine.core.dag.actions.SinkAction;
import org.apache.seatunnel.engine.server.checkpoint.ActionStateKey;
import org.apache.seatunnel.engine.server.checkpoint.ActionSubtaskState;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointBarrier;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointCloseReason;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointException;
import org.apache.seatunnel.engine.server.checkpoint.operation.TaskAcknowledgeOperation;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.record.Barrier;
import org.apache.seatunnel.engine.server.task.statemachine.SeaTunnelTaskState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Coordinator task that aggregates the commit infos reported by the sink writers of a {@link
 * SinkAction} and drives the sink's {@link SinkAggregatedCommitter}.
 *
 * <p>Writers register through {@link #receivedWriterRegister} and report per-checkpoint commit
 * infos through {@link #receivedWriterCommitInfo}. When {@code maxWriterSize} barriers (the sink
 * parallelism; presumably one per writer) have arrived for a checkpoint, the cached commit infos
 * are combined into one aggregated commit info. That info is kept as this task's checkpoint state
 * and acknowledged to the master. Aggregated infos are committed on {@link
 * #notifyCheckpointComplete} and aborted on {@link #notifyCheckpointAborted}.
 */
public class SinkAggregatedCommitterTask<CommandInfoT, AggregatedCommitInfoT> extends CoordinatorTask {
    private static final Logger log = LoggerFactory.getLogger(SinkAggregatedCommitterTask.class);
    private static final long serialVersionUID = 5906594537520393503L;

    private SeaTunnelTaskState currState;
    private final SinkAction<?, ?, CommandInfoT, AggregatedCommitInfoT> sink;
    /** Number of writers expected to register and to send a barrier per checkpoint. */
    private final int maxWriterSize;
    private final SinkAggregatedCommitter<CommandInfoT, AggregatedCommitInfoT> aggregatedCommitter;
    private transient Serializer<AggregatedCommitInfoT> aggregatedCommitInfoSerializer;
    /** Writer task id -> writer address. */
    private Map<Long, Address> writerAddressMap;
    /** Checkpoint id -> writer commit infos not yet combined. */
    private ConcurrentMap<Long, List<CommandInfoT>> commitInfoCache;
    /** Checkpoint id -> combined commit infos waiting for commit/abort. */
    private ConcurrentMap<Long, List<AggregatedCommitInfoT>> checkpointCommitInfoMap;
    /** Checkpoint id -> number of barriers received so far. */
    private Map<Long, Integer> checkpointBarrierCounter;
    private CompletableFuture<Void> completableFuture;
    private volatile boolean receivedSinkWriter;

    /**
     * @param jobID id forwarded to the {@link CoordinatorTask} constructor (presumably the job id)
     * @param taskLocation location of this task
     * @param sinkAction sink whose parallelism determines the expected number of writers
     * @param sinkAggregatedCommitter committer used to combine, commit and abort commit infos
     */
    public SinkAggregatedCommitterTask(
            long jobID,
            TaskLocation taskLocation,
            SinkAction<?, ?, CommandInfoT, AggregatedCommitInfoT> sinkAction,
            SinkAggregatedCommitter<CommandInfoT, AggregatedCommitInfoT> sinkAggregatedCommitter) {
        super(jobID, taskLocation);
        this.sink = sinkAction;
        this.aggregatedCommitter = sinkAggregatedCommitter;
        this.maxWriterSize = sinkAction.getParallelism();
        this.receivedSinkWriter = false;
    }

    /** Initializes the per-run state maps and resolves the aggregated commit info serializer. */
    @Override
    public void init() throws Exception {
        super.init();
        this.currState = SeaTunnelTaskState.INIT;
        this.checkpointBarrierCounter = new ConcurrentHashMap<>();
        this.commitInfoCache = new ConcurrentHashMap<>();
        this.writerAddressMap = new ConcurrentHashMap<>();
        this.checkpointCommitInfoMap = new ConcurrentHashMap<>();
        this.completableFuture = new CompletableFuture<>();
        this.aggregatedCommitInfoSerializer =
                this.sink.getSink().getAggregatedCommitInfoSerializer().get();
        log.debug("starting seatunnel sink aggregated committer task, sink name[{}] ", this.sink.getName());
    }

    /**
     * Records a writer's address. Once {@code maxWriterSize} distinct writer task ids are known, the
     * task may leave the STARTING state.
     */
    public void receivedWriterRegister(TaskLocation taskLocation, Address address) {
        writerAddressMap.put(taskLocation.getTaskID(), address);
        if (writerAddressMap.size() >= maxWriterSize) {
            receivedSinkWriter = true;
        }
    }

    /** Advances the state machine by one step and reports the resulting progress. */
    @Override
    @NonNull
    public ProgressState call() throws Exception {
        stateProcess();
        return this.progress.toState();
    }

    /**
     * Performs at most one state transition. The RUNNING and PREPARE_CLOSE states sleep for 100 ms
     * while idle so repeated {@link #call()} invocations do not spin.
     */
    protected void stateProcess() throws Exception {
        switch (currState) {
            case INIT:
                currState = SeaTunnelTaskState.WAITING_RESTORE;
                reportTaskStatus(SeaTunnelTaskState.WAITING_RESTORE);
                break;
            case WAITING_RESTORE:
                if (restoreComplete.isDone()) {
                    currState = SeaTunnelTaskState.READY_START;
                    reportTaskStatus(SeaTunnelTaskState.READY_START);
                }
                break;
            case READY_START:
                if (startCalled) {
                    currState = SeaTunnelTaskState.STARTING;
                }
                break;
            case STARTING:
                if (receivedSinkWriter) {
                    currState = SeaTunnelTaskState.RUNNING;
                }
                break;
            case RUNNING:
                if (prepareCloseStatus) {
                    currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else {
                    Thread.sleep(100L);
                }
                break;
            case PREPARE_CLOSE:
                if (closeCalled) {
                    currState = SeaTunnelTaskState.CLOSED;
                } else {
                    Thread.sleep(100L);
                }
                break;
            case CLOSED:
                close();
                break;
            case CANCELLING:
                close();
                currState = SeaTunnelTaskState.CANCELED;
                break;
            default:
                throw new IllegalArgumentException("Unknown Enumerator State: " + currState);
        }
    }

    /**
     * Closes the task and the aggregated committer. Progress is marked done and the completion
     * future is completed even if closing fails, so nothing waiting on them is left hanging.
     */
    @Override
    public void close() throws IOException {
        try {
            super.close();
            aggregatedCommitter.close();
        } finally {
            progress.done();
            completableFuture.complete(null);
        }
    }

    /**
     * Counts barriers per checkpoint. When the {@code maxWriterSize}-th barrier arrives, applies
     * prepare-close handling. For snapshot barriers it also combines the cached writer commit infos
     * and acknowledges the resulting state to the master.
     */
    @Override
    public void triggerBarrier(Barrier barrier) throws Exception {
        log.debug("trigger barrier for sink agg commit [{}]", barrier);
        long checkpointId = barrier.getId();
        int arrived = checkpointBarrierCounter.merge(checkpointId, 1, Integer::sum);
        if (arrived != maxWriterSize) {
            return;
        }
        // All expected barriers are in; drop the counter so the map does not grow for every
        // checkpoint over the lifetime of the job.
        checkpointBarrierCounter.remove(checkpointId);

        if (barrier.prepareClose()) {
            this.prepareCloseStatus = true;
            this.prepareCloseBarrierId.set(checkpointId);
        }
        if (barrier.snapshot()) {
            // Consume (rather than just read) the cached writer infos: once combined they are
            // tracked in checkpointCommitInfoMap and are no longer needed here.
            List<CommandInfoT> writerCommitInfos = commitInfoCache.remove(checkpointId);
            if (writerCommitInfos != null) {
                log.debug("commitInfoCache contains Key [{}]", checkpointId);
                AggregatedCommitInfoT combined = aggregatedCommitter.combine(writerCommitInfos);
                log.debug("get the aggregatedCommitInfoT [{}]", combined);
                checkpointCommitInfoMap.put(checkpointId, Collections.singletonList(combined));
            }
            // Read once so the logged state and the acknowledged state are the same list.
            List<AggregatedCommitInfoT> stateToStore =
                    checkpointCommitInfoMap.getOrDefault(checkpointId, Collections.emptyList());
            log.debug("final store commit info size [{}]", stateToStore.size());
            log.debug("final store commit info [{}]", stateToStore);
            ActionSubtaskState subtaskState =
                    new ActionSubtaskState(
                            ActionStateKey.of(sink),
                            -1, // NOTE(review): presumably "no writer subtask index" — confirm
                            serializeStates(aggregatedCommitInfoSerializer, stateToStore));
            getExecutionContext()
                    .sendToMaster(
                            new TaskAcknowledgeOperation(
                                    taskLocation,
                                    (CheckpointBarrier) barrier,
                                    Collections.singletonList(subtaskState)))
                    .join();
        }
    }

    /**
     * Deserializes the aggregated commit infos from the restored states, commits them, and then
     * marks restore as complete.
     */
    @Override
    public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
        log.debug("restoreState for sink agg committer [{}]", actionStateList);
        List<AggregatedCommitInfoT> restored =
                actionStateList.stream()
                        .map(ActionSubtaskState::getState)
                        .flatMap(state -> state.stream())
                        .map(bytes -> ExceptionUtil.sneaky(
                                () -> aggregatedCommitInfoSerializer.deserialize(bytes)))
                        .collect(Collectors.toList());
        aggregatedCommitter.commit(restored);
        restoreComplete.complete(null);
        log.debug("restoreState for sink agg committer [{}] finished", actionStateList);
    }

    /** Caches a writer's commit info until the barrier of {@code checkpointId} is complete. */
    public void receivedWriterCommitInfo(long checkpointId, CommandInfoT commitInfo) {
        log.debug("received writer commit infos checkpoint id [{}], commitInfos [{}]", checkpointId, commitInfo);
        commitInfoCache.computeIfAbsent(checkpointId, id -> new CopyOnWriteArrayList<>()).add(commitInfo);
    }

    /** Returns a mutable copy of the sink's jar URLs. */
    @Override
    public Set<URL> getJarsUrl() {
        return new HashSet<>(sink.getJarUrls());
    }

    /**
     * Commits the aggregated infos of every checkpoint with id {@code <= checkpointId}.
     *
     * @throws CheckpointException with {@code AGGREGATE_COMMIT_ERROR} if the committer returns a
     *     non-empty list, which is treated as a commit failure
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        List<Long> finishedIds =
                checkpointCommitInfoMap.keySet().stream()
                        .filter(id -> id <= checkpointId)
                        .sorted() // commit older checkpoints first, deterministically
                        .collect(Collectors.toList());
        List<AggregatedCommitInfoT> toCommit = new ArrayList<>();
        for (Long id : finishedIds) {
            // remove() both reads and drops the entry atomically.
            List<AggregatedCommitInfoT> infos = checkpointCommitInfoMap.remove(id);
            if (infos != null) {
                toCommit.addAll(infos);
            }
        }
        List<AggregatedCommitInfoT> commitResult = aggregatedCommitter.commit(toCommit);
        tryClose(checkpointId);
        if (CollectionUtils.isEmpty(commitResult)) {
            return;
        }
        log.error("aggregated committer error: {}", commitResult.size());
        throw new CheckpointException(CheckpointCloseReason.AGGREGATE_COMMIT_ERROR);
    }

    /**
     * Aborts the aggregated infos of {@code checkpointId}. The entry is removed before aborting so
     * that a failing abort can never leave it behind to be committed by a later checkpoint, and an
     * empty list (not {@code null}) is passed when nothing was aggregated.
     */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        List<AggregatedCommitInfoT> removed = checkpointCommitInfoMap.remove(checkpointId);
        List<AggregatedCommitInfoT> toAbort = removed != null ? removed : Collections.emptyList();
        aggregatedCommitter.abort(toAbort);
        tryClose(checkpointId);
    }
}
