/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import io.trino.hive.$internal.org.apache.commons.logging.Log;
import io.trino.hive.$internal.org.apache.commons.logging.LogFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.ForwardWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.PreOrderOnceWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TypeRule;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFStruct;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

/**
 * Optimizer transform that visits filter operators and, when a filter predicate is an
 * {@code IN} whose first argument is a {@code struct(...)} that mixes fields made only of
 * partition/virtual columns and constants with other fields, ANDs the predicate with extra
 * {@code struct(...) IN (struct(...), ...)} clauses built from just those fields, one clause
 * per table alias. The original predicate is always kept.
 */
public class PartitionColumnsSeparator
extends Transform {
    private static final Log LOG = LogFactory.getLog(PartitionColumnsSeparator.class);
    // Function names are read from each UDF's @Description annotation.
    private static final String IN_UDF = GenericUDFIn.class.getAnnotation(Description.class).name();
    private static final String STRUCT_UDF = GenericUDFStruct.class.getAnnotation(Description.class).name();
    private static final String AND_UDF = GenericUDFOPAnd.class.getAnnotation(Description.class).name();

    /**
     * Walks the operator graph starting at the top operators of {@code pctx} and runs
     * {@link StructInTransformer} on nodes matched by the filter-operator rule.
     *
     * @param pctx parse context supplying the top operators
     * @return the same {@code pctx} instance; filter predicates are replaced in place
     * @throws SemanticException if walking the graph or building a predicate fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", FilterOperator.getOperatorName() + "%"), new StructInTransformer());
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, opRules, null);
        ForwardWalker ogw = new ForwardWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Expression-level processor: for an {@code IN(struct(...), ...)} node it returns the
     * derived per-table-alias predicate, or {@code null} when the node is not eligible.
     */
    private class StructInExprProcessor
    implements NodeProcessor {
        /**
         * Memo of {@link #exprContainsOnlyPartitionColOrVirtualColOrConstants} results,
         * keyed by node identity.
         */
        private final Map<ExprNodeDesc, Boolean> exprNodeToPartOrVirtualColOrConstExpr = new IdentityHashMap<ExprNodeDesc, Boolean>();

        private StructInExprProcessor() {
        }

        /**
         * Returns {@code true} when no column reachable from {@code en} reports
         * {@code getIsPartitionColOrVirtualCol() == false}. A {@code null} node and a node
         * without children both count as {@code true}. Results are memoized per node.
         */
        private boolean exprContainsOnlyPartitionColOrVirtualColOrConstants(ExprNodeDesc en) {
            if (en == null) {
                return true;
            }
            // Only non-null Boolean values are ever stored, so a single lookup is enough.
            Boolean cached = this.exprNodeToPartOrVirtualColOrConstExpr.get(en);
            if (cached != null) {
                return cached;
            }
            boolean result = true;
            if (en instanceof ExprNodeColumnDesc) {
                result = ((ExprNodeColumnDesc)en).getIsPartitionColOrVirtualCol();
            } else if (en.getChildren() != null) {
                for (ExprNodeDesc cn : en.getChildren()) {
                    if (!this.exprContainsOnlyPartitionColOrVirtualColOrConstants(cn)) {
                        result = false;
                        break;
                    }
                }
            }
            this.exprNodeToPartOrVirtualColOrConstExpr.put(en, result);
            return result;
        }

        /**
         * Returns {@code true} when at least one direct child of {@code en} consists only of
         * partition/virtual columns and constants and {@link #getTableAlias} yields a
         * non-null alias for it.
         */
        private boolean hasAtleastOneSubExprWithPartColOrVirtualColWithOneTableAlias(ExprNodeDesc en) {
            if (en == null || en.getChildren() == null) {
                return false;
            }
            for (ExprNodeDesc cn : en.getChildren()) {
                if (this.exprContainsOnlyPartitionColOrVirtualColOrConstants(cn) && this.getTableAlias(cn) != null) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns {@code true} when the whole of {@code en} consists only of partition/virtual
         * columns and constants and {@link #getTableAliasHelper} accepts it.
         */
        private boolean hasAllSubExprWithConstOrPartColOrVirtualColWithOneTableAlias(ExprNodeDesc en) {
            if (!this.exprContainsOnlyPartitionColOrVirtualColOrConstants(en)) {
                return false;
            }
            Set<String> aliases = new HashSet<String>();
            Set<ExprNodeDesc> visited = new HashSet<ExprNodeDesc>();
            return this.getTableAliasHelper(en, aliases, visited);
        }

        /**
         * Returns {@code en} cast to a function descriptor when its UDF is
         * {@link GenericUDFIn}; otherwise {@code null}.
         */
        private ExprNodeGenericFuncDesc getInExprNode(ExprNodeDesc en) {
            // instanceof is false for null, so no separate null check is needed.
            if (en instanceof ExprNodeGenericFuncDesc && ((ExprNodeGenericFuncDesc)en).getGenericUDF() instanceof GenericUDFIn) {
                return (ExprNodeGenericFuncDesc)en;
            }
            return null;
        }

        /**
         * Depth-first collection of the table aliases referenced below {@code en} into
         * {@code s}.
         *
         * @param en      expression to inspect
         * @param s       accumulator of table aliases seen so far
         * @param visited nodes already traversed; children found here are skipped
         * @return {@code false} when {@code en} itself is a column whose alias differs from an
         *         alias already in {@code s}, or when a recursive call returns {@code false};
         *         {@code true} otherwise
         */
        private boolean getTableAliasHelper(ExprNodeDesc en, Set<String> s, Set<ExprNodeDesc> visited) {
            visited.add(en);
            if (en instanceof ExprNodeColumnDesc) {
                if (s.size() > 0 && !s.contains(((ExprNodeColumnDesc)en).getTabAlias())) {
                    return false;
                }
                if (s.size() == 0) {
                    s.add(((ExprNodeColumnDesc)en).getTabAlias());
                }
                return true;
            }
            if (en.getChildren() == null) {
                return true;
            }
            for (ExprNodeDesc cn : en.getChildren()) {
                if (visited.contains(cn)) continue;
                if (cn instanceof ExprNodeColumnDesc) {
                    // NOTE(review): direct column children are added without the single-alias
                    // check applied above, so `s` can end up holding several aliases while
                    // this method still returns true -- confirm whether that is intended.
                    s.add(((ExprNodeColumnDesc)cn).getTabAlias());
                    continue;
                }
                if (cn instanceof ExprNodeConstantDesc || this.getTableAliasHelper(cn, s, visited)) continue;
                return false;
            }
            return true;
        }

        /**
         * Returns the aliases collected by {@link #getTableAliasHelper} concatenated in set
         * iteration order, or {@code null} when the helper rejects {@code en} or no alias was
         * collected.
         */
        private String getTableAlias(ExprNodeDesc en) {
            Set<String> aliases = new HashSet<String>();
            Set<ExprNodeDesc> visited = new HashSet<ExprNodeDesc>();
            boolean singleTableAlias = this.getTableAliasHelper(en, aliases, visited);
            if (!singleTableAlias || aliases.size() == 0) {
                return null;
            }
            StringBuilder ans = new StringBuilder();
            for (String alias : aliases) {
                ans.append(alias);
            }
            return ans.toString();
        }

        /**
         * Extracts field {@code fieldIdx} of one IN-list element as a fresh constant.
         * Handles a constant whose value is a {@link List} and a function descriptor whose
         * child at {@code fieldIdx} is a constant.
         *
         * @return the new constant, or {@code null} when the element has neither shape or
         *         has no field at {@code fieldIdx}
         */
        private ExprNodeConstantDesc extractConstantField(ExprNodeDesc rhsStruct, int fieldIdx) {
            if (rhsStruct instanceof ExprNodeConstantDesc) {
                Object value = ((ExprNodeConstantDesc)rhsStruct).getValue();
                if (!(value instanceof List)) {
                    return null;
                }
                List<?> fields = (List<?>)value;
                if (fieldIdx >= fields.size()) {
                    return null;
                }
                return new ExprNodeConstantDesc(fields.get(fieldIdx));
            }
            if (rhsStruct instanceof ExprNodeGenericFuncDesc) {
                List<ExprNodeDesc> fields = rhsStruct.getChildren();
                if (fields == null || fieldIdx >= fields.size() || !(fields.get(fieldIdx) instanceof ExprNodeConstantDesc)) {
                    return null;
                }
                return new ExprNodeConstantDesc(((ExprNodeConstantDesc)fields.get(fieldIdx)).getValue());
            }
            return null;
        }

        /**
         * Builds the derived predicate for an {@code IN(struct(...), ...)} node.
         *
         * <p>Returns {@code null} (no rewrite) when the node is not an {@code IN}, its first
         * operand is not a {@code struct} call, no struct field qualifies, every field already
         * qualifies, or an IN-list element cannot be read as a struct of constants.
         *
         * @param nd          expression node being visited; expected to be an {@link ExprNodeDesc}
         * @param stack       unused
         * @param procCtx     unused
         * @param nodeOutputs unused
         * @return a single {@code IN} expression when one table alias was found, an
         *         {@code AND} of the per-alias {@code IN} expressions when several were found,
         *         or {@code null}
         * @throws SemanticException propagated from function lookup or descriptor creation
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            ExprNodeGenericFuncDesc fd = this.getInExprNode((ExprNodeDesc)nd);
            if (fd == null) {
                if (LOG.isDebugEnabled()) {
                    // Log the visited node: fd is always null on this branch.
                    LOG.debug("Partition columns not separated for " + nd + ", is not IN operator");
                }
                return null;
            }
            List<ExprNodeDesc> children = fd.getChildren();
            if (children == null || children.isEmpty()) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Partition columns not separated for " + fd + ", IN operator has no operands");
                }
                return null;
            }
            ExprNodeDesc lhs = children.get(0);
            if (!(lhs instanceof ExprNodeGenericFuncDesc) || !(((ExprNodeGenericFuncDesc)lhs).getGenericUDF() instanceof GenericUDFStruct)) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Partition columns not separated for " + fd + ", children size " + children.size() + ", child expression : " + (lhs == null ? null : lhs.getExprString()));
                }
                return null;
            }
            if (!this.hasAtleastOneSubExprWithPartColOrVirtualColWithOneTableAlias(lhs)) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Partition columns not separated for " + fd + ", there are no expression containing partition columns in struct fields");
                }
                return null;
            }
            if (this.hasAllSubExprWithConstOrPartColOrVirtualColWithOneTableAlias(lhs)) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Partition columns not separated for " + fd + ", all fields are expressions containing constants or only partition columns coming from same table");
                }
                return null;
            }

            // Group the qualifying struct fields, and the matching constants of every IN-list
            // element, by table alias.
            HashMap<String, TableInfo> tableAliasToInfo = new HashMap<String, TableInfo>();
            List<ExprNodeDesc> originalDescChildren = lhs.getChildren();
            for (int i = 0; i < originalDescChildren.size(); ++i) {
                ExprNodeDesc en = originalDescChildren.get(i);
                if (!this.exprContainsOnlyPartitionColOrVirtualColOrConstants(en)) {
                    continue;
                }
                String tabAlias = this.getTableAlias(en);
                if (tabAlias == null) {
                    continue;
                }
                TableInfo currTableInfo = tableAliasToInfo.get(tabAlias);
                if (currTableInfo == null) {
                    currTableInfo = new TableInfo();
                    tableAliasToInfo.put(tabAlias, currTableInfo);
                }
                currTableInfo.exprNodeLHSDescriptor.add(en);
                // children[1..] are the IN-list elements; element j maps to RHS slot j - 1.
                for (int j = 1; j < children.size(); ++j) {
                    ExprNodeConstantDesc newConstStructElement = this.extractConstantField(children.get(j), i);
                    if (newConstStructElement == null) {
                        // The rewrite is optional, so skip it rather than fail on an
                        // IN-list element that is not a struct of constants.
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Partition columns not separated for " + fd + ", IN list element " + j + " is not a struct of constants");
                        }
                        return null;
                    }
                    if (currTableInfo.exprNodeRHSStructs.size() < j) {
                        List<ExprNodeDesc> newConstStructList = new ArrayList<ExprNodeDesc>();
                        newConstStructList.add(newConstStructElement);
                        currTableInfo.exprNodeRHSStructs.add(newConstStructList);
                    } else {
                        currTableInfo.exprNodeRHSStructs.get(j - 1).add(newConstStructElement);
                    }
                }
            }

            // Emit one "struct(fields) IN (struct(consts), ...)" expression per table alias.
            ArrayList<ExprNodeDesc> subExpr = new ArrayList<ExprNodeDesc>(originalDescChildren.size() + 1);
            for (TableInfo info : tableAliasToInfo.values()) {
                ArrayList<ExprNodeDesc> currInStructExprList = new ArrayList<ExprNodeDesc>();
                currInStructExprList.add(ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo(STRUCT_UDF).getGenericUDF(), STRUCT_UDF, info.exprNodeLHSDescriptor));
                for (List<ExprNodeDesc> currConstStruct : info.exprNodeRHSStructs) {
                    currInStructExprList.add(ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo(STRUCT_UDF).getGenericUDF(), STRUCT_UDF, currConstStruct));
                }
                subExpr.add(new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, FunctionRegistry.getFunctionInfo(IN_UDF).getGenericUDF(), currInStructExprList));
            }
            if (subExpr.size() == 1) {
                return subExpr.get(0);
            }
            return new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, FunctionRegistry.getFunctionInfo(AND_UDF).getGenericUDF(), subExpr);
        }

        /**
         * Per-table-alias accumulator: the struct fields kept for that alias and, for each
         * IN-list element, the constants at the same field positions.
         */
        class TableInfo {
            List<ExprNodeDesc> exprNodeLHSDescriptor = new ArrayList<ExprNodeDesc>();
            List<List<ExprNodeDesc>> exprNodeRHSStructs = new ArrayList<List<ExprNodeDesc>>();
        }
    }

    /**
     * Operator-level processor: asks {@link StructInExprProcessor} for a derived predicate
     * and, when one is produced, replaces the filter predicate with
     * {@code original AND derived}.
     */
    private class StructInTransformer
    implements NodeProcessor {
        private StructInTransformer() {
        }

        /**
         * Rewrites the predicate of the visited filter operator in place when a derived
         * predicate is available.
         *
         * @param nd          node being visited; cast to {@link FilterOperator}
         * @param stack       unused
         * @param procCtx     unused
         * @param nodeOutputs unused
         * @return always {@code null}
         * @throws SemanticException propagated from predicate generation
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            FilterOperator filterOp = (FilterOperator)nd;
            FilterDesc filterDesc = (FilterDesc)filterOp.getConf();
            ExprNodeDesc predicate = filterDesc.getPredicate();
            ExprNodeDesc newPredicate = this.generateInClauses(predicate);
            if (newPredicate == null) {
                return null;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Generated new predicate with IN clause: " + newPredicate);
            }
            // Keep the original predicate and AND the derived one onto it.
            ArrayList<ExprNodeDesc> subExpr = new ArrayList<ExprNodeDesc>(2);
            subExpr.add(predicate);
            subExpr.add(newPredicate);
            ExprNodeGenericFuncDesc newFilterPredicate = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, FunctionRegistry.getFunctionInfo(AND_UDF).getGenericUDF(), subExpr);
            filterDesc.setPredicate(newFilterPredicate);
            return null;
        }

        /**
         * Walks {@code predicate} with a fresh {@link StructInExprProcessor} registered for
         * function descriptors and returns the output recorded for the root predicate node.
         *
         * @return the walker output stored under {@code predicate}, or {@code null} if none
         * @throws SemanticException propagated from the walk
         */
        private ExprNodeDesc generateInClauses(ExprNodeDesc predicate) throws SemanticException {
            LinkedHashMap<Rule, NodeProcessor> exprRules = new LinkedHashMap<Rule, NodeProcessor>();
            exprRules.put(new TypeRule(ExprNodeGenericFuncDesc.class), new StructInExprProcessor());
            DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, exprRules, null);
            PreOrderOnceWalker egw = new PreOrderOnceWalker(disp);
            ArrayList<Node> startNodes = new ArrayList<Node>();
            startNodes.add(predicate);
            HashMap<Node, Object> outputMap = new HashMap<Node, Object>();
            egw.startWalking(startNodes, outputMap);
            return (ExprNodeDesc)outputMap.get(predicate);
        }
    }
}

