/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import io.trino.hive.$internal.com.google.common.annotations.VisibleForTesting;
import io.trino.hive.$internal.com.google.common.base.Preconditions;
import io.trino.hive.$internal.org.apache.commons.lang3.mutable.MutableInt;
import io.trino.hive.$internal.org.slf4j.Logger;
import io.trino.hive.$internal.org.slf4j.LoggerFactory;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveWritableObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;

public class DynamicPartitionPruner {
    private static final Logger LOG = LoggerFactory.getLogger(DynamicPartitionPruner.class);
    // Initializer context; used to register for vertex state updates and to query per-vertex task counts.
    private final InputInitializerContext context;
    // Map work whose path-to-partition entries are pruned and whose event-source maps drive initialization.
    private final MapWork work;
    // Job configuration handed to each SourceInfo so it can initialize its deserializer.
    private final JobConf jobConf;
    // Source vertex name -> one SourceInfo per column that the source sends events for.
    private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<String, List<SourceInfo>>();
    // Scratch writable reused for every serialized row read from an event payload.
    private final BytesWritable writable = new BytesWritable();
    // Hand-off queue between addEvent()/checkForSourceCompletion() (producers) and processEvents() (consumer).
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();
    // Sources that have not yet delivered all of their events; also the monitor guarding the event counters.
    private final Set<String> sourcesWaitingForEvents = new HashSet<String>();
    // Per source: starts at 0, is decremented once per column during initialize() (so it is negative),
    // and is replaced by (#columns * #tasks) in processVertex() once the vertex has succeeded.
    private final Map<String, MutableInt> numExpectedEventsPerSource = new HashMap<String, MutableInt>();
    // Per source: number of events accepted by addEvent() so far.
    private final Map<String, MutableInt> numEventsSeenPerSource = new HashMap<String, MutableInt>();
    // Total number of SourceInfo objects created by initialize(); only used for logging in prune().
    private int sourceInfoCount = 0;
    // Sentinel enqueued when the last waiting source completes; compared by identity in processEvents().
    private final Object endOfEvents = new Object();
    // Number of events accepted by addEvent(); validated against the expected total in prunePartitions().
    private int totalEventCount = 0;

    /**
     * Creates a pruner and immediately builds its per-source bookkeeping via
     * {@code initialize()}, which runs while holding this object's monitor.
     *
     * @param context initializer context used for vertex state registration and task counts
     * @param work map work describing the event sources and the partitions to prune
     * @param jobConf job configuration passed on to each {@link SourceInfo}
     * @throws SerDeException if {@code initialize()} fails to set up a source
     */
    public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf) throws SerDeException {
        this.context = context;
        this.work = work;
        this.jobConf = jobConf;
        synchronized (this) {
            initialize();
        }
    }

    /**
     * Waits for all event sources to report and then prunes the partition list.
     *
     * <p>If no source is being waited on, the method returns without doing anything.
     * Otherwise it registers for {@code SUCCEEDED} state updates of every waiting source
     * vertex (under the {@code sourcesWaitingForEvents} monitor), drains the event queue
     * with {@code processEvents()} and finally calls {@code prunePartitions()}.
     *
     * @throws SerDeException if an event payload cannot be deserialized
     * @throws IOException if an event payload cannot be read
     * @throws InterruptedException if interrupted while waiting on the event queue
     * @throws HiveException if pruning fails or the received event count is wrong
     */
    public void prune() throws SerDeException, IOException, InterruptedException, HiveException {
        final Set<VertexState> succeededOnly = Collections.singleton(VertexState.SUCCEEDED);
        synchronized (this.sourcesWaitingForEvents) {
            if (this.sourcesWaitingForEvents.isEmpty()) {
                return;
            }
            for (String vertexName : this.sourcesWaitingForEvents) {
                this.context.registerForVertexStateUpdates(vertexName, succeededOnly);
            }
        }
        LOG.info("Waiting for events (" + this.sourceInfoCount + " sources) ...");
        // Blocks until the end-of-events sentinel has been taken from the queue.
        processEvents();
        prunePartitions();
        LOG.info("Ok to proceed.");
    }

    /**
     * Exposes the internal event queue.
     *
     * @return the live queue instance used by this pruner (not a copy)
     */
    public BlockingQueue<Object> getQueue() {
        return queue;
    }

    /** Resets the source-info bookkeeping: zeroes the counter and empties the map. */
    private void clear() {
        sourceInfoCount = 0;
        sourceInfoMap.clear();
    }

    /**
     * Rebuilds all per-source bookkeeping from the event-source maps of {@code work}.
     *
     * <p>For every source vertex name this registers the source as waiting for events,
     * resets its seen/expected counters and creates one {@link SourceInfo} per table
     * descriptor. The table, column-name, column-type and partition-key-expression lists
     * of a source are walked in lock-step, one entry of each per {@link SourceInfo}.
     *
     * <p>The expected-event counter of a source is decremented once per column, so after
     * this method it holds the negated column count; {@code processVertex} later turns it
     * into the real expected number of events.
     *
     * <p>When a column name has been seen before (for this or an earlier source), the new
     * {@link SourceInfo} is pointed at the same {@code values} set and {@code skipPruning}
     * flag as the previous one, so all infos for one column share their state.
     *
     * @throws SerDeException if a {@link SourceInfo} cannot be created
     */
    private void initialize() throws SerDeException {
        this.clear();
        // Most recent SourceInfo created for each column name, across all sources.
        Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
        Set<String> sources = this.work.getEventSourceTableDescMap().keySet();
        this.sourcesWaitingForEvents.addAll(sources);
        for (String s : sources) {
            MutableInt expectedEvents = new MutableInt(0);
            this.numExpectedEventsPerSource.put(s, expectedEvents);
            this.numEventsSeenPerSource.put(s, new MutableInt(0));
            List<TableDesc> tables = this.work.getEventSourceTableDescMap().get(s);
            List<String> columnNames = this.work.getEventSourceColumnNameMap().get(s);
            List<String> columnTypes = this.work.getEventSourceColumnTypeMap().get(s);
            List<ExprNodeDesc> partKeyExprs = this.work.getEventSourcePartKeyExprMap().get(s);
            Iterator<String> cit = columnNames.iterator();
            Iterator<String> typit = columnTypes.iterator();
            Iterator<ExprNodeDesc> pit = partKeyExprs.iterator();
            for (TableDesc t : tables) {
                // One decrement per column; the sign marks the count as "not yet final".
                expectedEvents.decrement();
                ++this.sourceInfoCount;
                String columnName = cit.next();
                String columnType = typit.next();
                ExprNodeDesc partKeyExpr = pit.next();
                SourceInfo si = this.createSourceInfo(t, partKeyExpr, columnName, columnType, this.jobConf);
                List<SourceInfo> sis = this.sourceInfoMap.get(s);
                if (sis == null) {
                    sis = new ArrayList<SourceInfo>();
                    this.sourceInfoMap.put(s, sis);
                }
                sis.add(si);
                // Share value set and skip flag between all infos that target the same column.
                SourceInfo previous = columnMap.get(columnName);
                if (previous != null) {
                    si.values = previous.values;
                    si.skipPruning = previous.skipPruning;
                }
                columnMap.put(columnName, si);
            }
        }
    }

    /**
     * Prunes partitions for every registered source/column pair and then verifies that
     * the number of received events matches the number that was expected.
     *
     * <p>The expected total is the sum, over every {@link SourceInfo}, of the task count
     * of its source vertex as reported by the initializer context.
     *
     * @throws HiveException if pruning a source fails, or if the expected event total
     *         differs from {@code totalEventCount}
     */
    private void prunePartitions() throws HiveException {
        int expectedTotal = 0;
        for (Map.Entry<String, List<SourceInfo>> perSource : this.sourceInfoMap.entrySet()) {
            final String vertexName = perSource.getKey();
            final List<SourceInfo> columnInfos = perSource.getValue();
            for (SourceInfo info : columnInfos) {
                final int numTasks = this.context.getVertexNumTasks(vertexName);
                LOG.info("Expecting " + numTasks + " events for vertex " + vertexName + ", for column " + info.columnName);
                expectedTotal += numTasks;
                prunePartitionSingleSource(vertexName, info);
            }
        }
        if (expectedTotal == this.totalEventCount) {
            return;
        }
        LOG.error("Expecting: " + expectedTotal + ", received: " + this.totalEventCount);
        throw new HiveException("Incorrect event count in dynamic partition pruning");
    }

    /**
     * Prunes the partition list using the values collected for one source/column pair.
     *
     * <p>Nothing is pruned when the info's {@code skipPruning} flag is set. Otherwise a
     * string-to-column-type converter and an evaluator for the partition key expression
     * are built and handed to {@code applyFilterToPartitions}.
     *
     * @param source name of the source vertex (used for logging only)
     * @param si column information and collected values for that source
     * @throws HiveException if the evaluator cannot be initialized or filtering fails
     */
    @VisibleForTesting
    protected void prunePartitionSingleSource(String source, SourceInfo si) throws HiveException {
        if (si.skipPruning.get()) {
            LOG.info("Skip pruning on " + source + ", column " + si.columnName);
            return;
        }
        final String columnName = si.columnName;
        final Set<Object> values = si.values;
        if (LOG.isDebugEnabled()) {
            StringBuilder message = new StringBuilder("Pruning ").append(columnName).append(" with ");
            for (Object candidate : values) {
                // append(Object) renders a null element as "null".
                message.append(candidate).append(", ");
            }
            LOG.debug(message.toString());
        }
        // Inspector for the partition column's declared primitive type.
        AbstractPrimitiveWritableObjectInspector oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                TypeInfoFactory.getPrimitiveTypeInfo(si.columnType));
        // Partition values are stored as strings in the partition spec; convert them to that type.
        ObjectInspectorConverters.Converter converter = ObjectInspectorConverters.getConverter(
                (ObjectInspector) PrimitiveObjectInspectorFactory.javaStringObjectInspector, (ObjectInspector) oi);
        // Single-field row shape that the partition key expression is evaluated against.
        StandardStructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.singletonList(oi));
        ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
        eval.initialize(soi);
        applyFilterToPartitions(converter, eval, columnName, values);
    }

    /**
     * Removes every path from the work's path-to-partition map whose partition value for
     * {@code columnName} does not appear in {@code values}.
     *
     * <p>For each path the partition value string is converted with {@code converter},
     * passed through {@code eval} as a one-column row, and the result is looked up in
     * {@code values}. Paths that do not match are removed from the map through the key
     * iterator and also dropped via {@code work.removePathToAlias}.
     *
     * @param converter converts the partition value string to the column's type
     * @param eval evaluator for the partition key expression
     * @param columnName partition column to filter on
     * @param values set of values to keep
     * @throws HiveException if evaluating the partition key expression fails
     * @throws IllegalStateException if a partition has no spec or no value for the column
     */
    private void applyFilterToPartitions(ObjectInspectorConverters.Converter converter, ExprNodeEvaluator eval, String columnName, Set<Object> values) throws HiveException {
        // Single-slot row buffer reused for every partition.
        final Object[] row = new Object[1];
        for (Iterator<Path> paths = this.work.getPathToPartitionInfo().keySet().iterator(); paths.hasNext(); ) {
            final Path path = paths.next();
            final PartitionDesc partition = this.work.getPathToPartitionInfo().get(path);
            final Map<String, String> spec = partition.getPartSpec();
            if (spec == null) {
                throw new IllegalStateException("No partition spec found in dynamic pruning");
            }
            final String rawValue = spec.get(columnName);
            if (rawValue == null) {
                throw new IllegalStateException("Could not find partition value for column: " + columnName);
            }
            final Object converted = converter.convert(rawValue);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Converted partition value: " + converted + " original (" + rawValue + ")");
            }
            row[0] = converted;
            final Object evaluated = eval.evaluate(row);
            if (LOG.isDebugEnabled()) {
                LOG.debug("part key expr applied: " + evaluated);
            }
            if (!values.contains(evaluated)) {
                LOG.info("Pruning path: " + path);
                // Removing through the key-set iterator drops the entry from the underlying map.
                paths.remove();
                this.work.removePathToAlias(path);
            }
        }
    }

    /**
     * Factory hook for {@link SourceInfo}; {@code initialize()} creates every info
     * through this method.
     *
     * @param t table descriptor supplying the deserializer class and properties
     * @param partKeyExpr partition key expression for the column
     * @param columnName name of the partition column
     * @param columnType type name of the partition column
     * @param jobConf configuration used to initialize the deserializer
     * @return a newly constructed {@link SourceInfo}
     * @throws SerDeException if the {@link SourceInfo} constructor fails
     */
    @VisibleForTesting
    protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName, String columnType, JobConf jobConf) throws SerDeException {
        final SourceInfo created = new SourceInfo(t, partKeyExpr, columnName, columnType, jobConf);
        return created;
    }

    /**
     * Takes events from the queue and processes their payloads until the end-of-events
     * sentinel is taken. Blocks while the queue is empty.
     *
     * @throws SerDeException if a payload row cannot be deserialized
     * @throws IOException if a payload cannot be read
     * @throws InterruptedException if interrupted while waiting on the queue
     */
    private void processEvents() throws SerDeException, IOException, InterruptedException {
        int processed = 0;
        while (true) {
            final Object next = this.queue.take();
            // The sentinel is recognised by identity, not by equals().
            if (next == this.endOfEvents) {
                break;
            }
            final InputInitializerEvent event = (InputInitializerEvent) next;
            LOG.info("Input event: " + event.getTargetInputName() + ", " + event.getTargetVertexName() + ", " + (event.getUserPayload().limit() - event.getUserPayload().position()));
            processPayload(event.getUserPayload(), event.getSourceVertexName());
            processed++;
        }
        LOG.info("Received events: " + processed);
    }

    /**
     * Decodes one event payload and records its values for the matching column.
     *
     * <p>The payload is read in this order: a UTF column name, then (unless pruning has
     * already been marked as skipped for that column) a boolean skip flag, then zero or
     * more serialized rows until the buffer is exhausted. A {@code true} skip flag sets
     * the column's {@code skipPruning}; otherwise each row is deserialized and its field
     * value, copied to a standard object, is added to the column's value set.
     *
     * @param payload buffer holding the event payload; it is consumed by this call
     * @param sourceName name of the vertex that sent the event
     * @return {@code sourceName}, unchanged
     * @throws SerDeException if a row cannot be deserialized
     * @throws IOException if the payload cannot be read
     * @throws IllegalStateException if no info exists for the source or for the column
     */
    @VisibleForTesting
    protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException {
        try (DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload))) {
            final String columnName = in.readUTF();
            LOG.info("Source of event: " + sourceName);
            final List<SourceInfo> candidates = this.sourceInfoMap.get(sourceName);
            if (candidates == null) {
                throw new IllegalStateException("no source info for event source: " + sourceName);
            }
            SourceInfo info = null;
            for (SourceInfo candidate : candidates) {
                if (columnName.equals(candidate.columnName)) {
                    info = candidate;
                    break;
                }
            }
            if (info == null) {
                throw new IllegalStateException("no source info for column: " + columnName);
            }
            // When pruning is already skipped for this column the rest of the payload is ignored.
            if (!info.skipPruning.get()) {
                if (in.readBoolean()) {
                    info.skipPruning.set(true);
                } else {
                    // The stream reads straight from the buffer, so the buffer tells us when we are done.
                    while (payload.hasRemaining()) {
                        this.writable.readFields(in);
                        final Object row = info.deserializer.deserialize(this.writable);
                        final Object fieldData = info.soi.getStructFieldData(row, info.field);
                        final Object value = ObjectInspectorUtils.copyToStandardObject(fieldData, info.fieldInspector);
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Adding: " + value + " to list of required partitions");
                        }
                        info.values.add(value);
                    }
                }
            }
        }
        return sourceName;
    }

    /**
     * Accepts an event from a source vertex and queues it for processing.
     *
     * <p>Events whose source vertex is not (or no longer) in the waiting set are ignored.
     * For accepted events the total and per-source counters are incremented, the event is
     * queued, and the source is checked for completion. All of this happens while holding
     * the {@code sourcesWaitingForEvents} monitor.
     *
     * @param event the initializer event to record
     * @throws IllegalStateException if the queue rejects the event, or if the source has
     *         delivered more events than expected
     */
    public void addEvent(InputInitializerEvent event) {
        synchronized (this.sourcesWaitingForEvents) {
            final String sourceVertex = event.getSourceVertexName();
            if (!this.sourcesWaitingForEvents.contains(sourceVertex)) {
                return;
            }
            ++this.totalEventCount;
            this.numEventsSeenPerSource.get(sourceVertex).increment();
            final boolean queued = this.queue.offer(event);
            if (!queued) {
                throw new IllegalStateException("Queue full");
            }
            checkForSourceCompletion(sourceVertex);
        }
    }

    /**
     * Records that a source vertex has succeeded and fixes its expected event count.
     *
     * <p>Until this call the source's expected count holds the negated number of columns.
     * It is replaced by (number of columns) times (number of tasks of the vertex), after
     * which the source is checked for completion. Runs under the
     * {@code sourcesWaitingForEvents} monitor.
     *
     * @param name name of the vertex that succeeded
     * @throws IllegalStateException if the stored expected count is not negative
     */
    public void processVertex(String name) {
        LOG.info("Vertex succeeded: " + name);
        synchronized (this.sourcesWaitingForEvents) {
            final MutableInt expected = this.numExpectedEventsPerSource.get(name);
            final int negatedColumnCount = expected.intValue();
            Preconditions.checkState(negatedColumnCount < 0, "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + negatedColumnCount);
            final int numTasks = this.context.getVertexNumTasks(name);
            expected.setValue(-negatedColumnCount * numTasks);
            checkForSourceCompletion(name);
        }
    }

    /**
     * Checks whether a source has delivered all of its expected events.
     *
     * <p>A negative expected count means the count is not final yet and nothing happens.
     * When seen and expected counts match, the source is removed from the waiting set;
     * if that empties the set, the end-of-events sentinel is queued. Both call sites in
     * this class invoke this method while holding the {@code sourcesWaitingForEvents}
     * monitor.
     *
     * @param name name of the source vertex to check
     * @throws IllegalStateException if more events were seen than expected, or if the
     *         queue rejects the sentinel
     */
    private void checkForSourceCompletion(String name) {
        final int expected = this.numExpectedEventsPerSource.get(name).intValue();
        if (expected < 0) {
            return;
        }
        final int seen = this.numEventsSeenPerSource.get(name).intValue();
        if (seen > expected) {
            throw new IllegalStateException("Received too many events for " + name + ", Expected=" + expected + ", Received=" + seen);
        }
        if (seen < expected) {
            return;
        }
        this.sourcesWaitingForEvents.remove(name);
        if (!this.sourcesWaitingForEvents.isEmpty()) {
            LOG.info("Waiting for " + this.sourcesWaitingForEvents.size() + " sources.");
            return;
        }
        // Last source finished: wake up the consumer loop in processEvents().
        if (!this.queue.offer(this.endOfEvents)) {
            throw new IllegalStateException("Queue full");
        }
    }

    /**
     * {@link InputStream} that reads directly from a {@link ByteBuffer}.
     *
     * <p>Every read advances the position of the wrapped buffer itself; no copy is made.
     * Callers can therefore use {@code hasRemaining()} on the same buffer to find out
     * whether the stream has more data.
     */
    private static class ByteBufferBackedInputStream
    extends InputStream {
        private final ByteBuffer buf;

        /**
         * @param buf buffer to read from; reading starts at its current position
         */
        public ByteBufferBackedInputStream(ByteBuffer buf) {
            this.buf = buf;
        }

        /**
         * Reads one byte.
         *
         * @return the next byte as a value in {@code 0..255}, or {@code -1} when the
         *         buffer has no bytes remaining
         */
        @Override
        public int read() throws IOException {
            if (!this.buf.hasRemaining()) {
                return -1;
            }
            // Mask to avoid sign extension of bytes >= 0x80.
            return this.buf.get() & 0xFF;
        }

        /**
         * Reads up to {@code len} bytes into {@code bytes} starting at {@code off}.
         *
         * @return the number of bytes read; {@code 0} when {@code len} is zero; {@code -1}
         *         when {@code len} is positive and the buffer has no bytes remaining
         * @throws NullPointerException if {@code bytes} is null
         * @throws IndexOutOfBoundsException if {@code off}/{@code len} do not describe a
         *         valid range of {@code bytes}
         */
        @Override
        public int read(byte[] bytes, int off, int len) throws IOException {
            if (bytes == null) {
                throw new NullPointerException("bytes");
            }
            if (off < 0 || len < 0 || len > bytes.length - off) {
                throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + bytes.length);
            }
            // InputStream contract: a zero-length read returns 0, even at end of stream.
            if (len == 0) {
                return 0;
            }
            if (!this.buf.hasRemaining()) {
                return -1;
            }
            final int count = Math.min(len, this.buf.remaining());
            this.buf.get(bytes, off, count);
            return count;
        }
    }

    /**
     * State kept for one (source vertex, partition column) pair: the partition key
     * expression, the deserialization machinery for incoming event rows, and the values
     * collected so far.
     *
     * <p>{@code values} and {@code skipPruning} are not final because
     * {@code DynamicPartitionPruner.initialize()} re-points them so that infos for the
     * same column share one set and one flag.
     */
    @VisibleForTesting
    static class SourceInfo {
        /** Partition key expression evaluated against each partition value. */
        public final ExprNodeDesc partKey;
        /** Deserializer for the rows carried in event payloads. */
        public final Deserializer deserializer;
        /** Struct inspector obtained from {@link #deserializer}. */
        public final StructObjectInspector soi;
        /** First field of {@link #soi}; the one whose data is extracted from each row. */
        public final StructField field;
        /** Standard object inspector for {@link #field}, used when copying values. */
        public final ObjectInspector fieldInspector;
        /** Values received for this column. */
        public Set<Object> values = new HashSet<Object>();
        /** Set to {@code true} when a source asks for pruning to be skipped. */
        public AtomicBoolean skipPruning = new AtomicBoolean();
        /** Name of the partition column. */
        public final String columnName;
        /** Type name of the partition column. */
        public final String columnType;

        /**
         * Test-only constructor: stores the column metadata and leaves every
         * deserialization-related field {@code null}. {@code table}, {@code jobConf} and
         * {@code forTesting} are not used.
         */
        @VisibleForTesting
        SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, String columnType, JobConf jobConf, Object forTesting) {
            this.deserializer = null;
            this.soi = null;
            this.field = null;
            this.fieldInspector = null;
            this.partKey = partKey;
            this.columnName = columnName;
            this.columnType = columnType;
        }

        /**
         * Instantiates and initializes the table's deserializer and captures the first
         * field of its struct object inspector.
         *
         * @param table supplies the deserializer class and its properties
         * @param partKey partition key expression for the column
         * @param columnName name of the partition column
         * @param columnType type name of the partition column
         * @param jobConf configuration used to initialize the deserializer
         * @throws SerDeException if the deserializer cannot be initialized or queried
         */
        public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, String columnType, JobConf jobConf) throws SerDeException {
            this.skipPruning.set(false);
            this.partKey = partKey;
            this.columnName = columnName;
            this.columnType = columnType;

            final Deserializer serde = (Deserializer) ReflectionUtils.newInstance(table.getDeserializerClass(), null);
            serde.initialize((Configuration) jobConf, table.getProperties());
            this.deserializer = serde;

            final ObjectInspector inspector = serde.getObjectInspector();
            LOG.debug("Type of obj insp: " + inspector.getTypeName());
            final StructObjectInspector structInspector = (StructObjectInspector) inspector;
            this.soi = structInspector;

            final List<? extends StructField> fields = structInspector.getAllStructFieldRefs();
            // More than one field is only logged; the first field is used regardless.
            if (fields.size() > 1) {
                LOG.error("expecting single field in input");
            }
            final StructField first = fields.get(0);
            this.field = first;
            this.fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(first.getFieldObjectInspector());
        }
    }
}

