package com.ontotext.trree.query.optimization;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.eclipse.rdf4j.query.algebra.BinaryTupleOperator;
import org.eclipse.rdf4j.query.algebra.Filter;
import org.eclipse.rdf4j.query.algebra.Join;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.ProjectionElem;
import org.eclipse.rdf4j.query.algebra.ProjectionElemList;
import org.eclipse.rdf4j.query.algebra.QueryModelNode;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.Union;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor;
import org.eclipse.rdf4j.query.algebra.helpers.StatementPatternCollector;
import org.eclipse.rdf4j.query.algebra.helpers.VarNameCollector;

/* loaded from: input_file:com/ontotext/trree/query/optimization/QueryUnionVisitor.class */
public class QueryUnionVisitor extends AbstractQueryModelVisitor<QueryOptimizerException> {
    private boolean isInsideFilter = false;
    private Filter filter;

    /**
     * Visits a filter and decides whether it becomes the filter tracked for union optimization.
     * <p>
     * The filter is remembered only when its argument is a {@link Join}, its parent node is a
     * {@link Projection} and no filter is currently being tracked. In every other case the
     * remembered filter is cleared; the {@code isInsideFilter} flag is left as it was.
     *
     * @param filter the filter node being visited
     * @throws QueryOptimizerException if thrown while the superclass continues the traversal
     */
    @Override
    public void meet(Filter filter) throws QueryOptimizerException {
        boolean opensFilterScope = filter.getArg() instanceof Join
                && filter.getParentNode() instanceof Projection
                && !this.isInsideFilter;
        if (opensFilterScope) {
            this.isInsideFilter = true;
        }
        this.filter = opensFilterScope ? filter : null;
        super.meet(filter);
    }

    /**
     * Visits a union and, when it qualifies, rewrites the surrounding filter/join structure.
     * <p>
     * The rewrite is attempted only if the union passes {@link #isOptimizableUnion(Union)} and
     * {@link #isRegularUnionJoin(Union)} and none of the variables used by the tracked filter's
     * condition is bound by the union. The tracked filter state is reset afterwards, before the
     * traversal continues into the union.
     *
     * @param union the union node being visited
     * @throws QueryOptimizerException if the rewrite or the continued traversal fails
     */
    @Override
    public void meet(Union union) throws QueryOptimizerException {
        boolean isCandidate = isOptimizableUnion(union) && isRegularUnionJoin(union);
        if (isCandidate) {
            Set<String> filterVarNames = fetchFilterVars();
            boolean unionBindsFilterVar = !Collections.disjoint(union.getBindingNames(), filterVarNames);
            if (!unionBindsFilterVar) {
                optimizeUnion(filterVarNames, union);
            }
        }
        // Whatever happened above, the tracked filter no longer applies below this union.
        this.isInsideFilter = false;
        this.filter = null;
        super.meet(union);
    }

    /**
     * Visits a projection: any filter tracked so far is forgotten before the traversal
     * continues into the projection.
     *
     * @param projection the projection node being visited
     * @throws QueryOptimizerException if thrown while the superclass continues the traversal
     */
    @Override
    public void meet(Projection projection) throws QueryOptimizerException {
        this.isInsideFilter = false;
        this.filter = null;
        super.meet(projection);
    }

    /**
     * Performs the filter/union rewrite for the currently tracked filter.
     * <p>
     * Steps: collect the statement patterns that mention a filter variable, verify they form one
     * connected group, build a sub-select ({@link Projection}) around a copy of them, strip the
     * original filter, and finally replace the union's parent join.
     *
     * @param set   names of the variables used by the tracked filter's condition
     * @param union the union whose parent join is to be replaced
     * @throws QueryOptimizerException if one of the steps reports an optimizer failure
     */
    private void optimizeUnion(Set<String> set, Union union) throws QueryOptimizerException {
        try {
            // Variable names to expose from the generated sub-select, in encounter order.
            Set<String> projectedVarNames = new LinkedHashSet<>();
            List<StatementPattern> filterPatterns = fetchMatchingStatementPatterns(set, projectedVarNames);
            checkDisjointPatternSets(filterPatterns);
            Projection filterSubSelect = generateFilterProjection(filterPatterns, projectedVarNames);
            replaceOriginalFilter(filterPatterns);
            replaceOriginalUnionParent(union, filterSubSelect, filterPatterns);
        } catch (QueryOptimizerException e) {
            throw e;
        } catch (RuntimeException ignored) {
            // Intentionally ignored: the helpers above throw RuntimeException to signal
            // "skip this optimisation" (e.g. disjoint pattern groups, patterns whose parent is
            // not a join). NOTE(review): a RuntimeException raised after replaceOriginalFilter
            // has run is swallowed as well and the tree modifications are not rolled back.
        } catch (Exception e) {
            // Only the message is carried over; no constructor accepting a cause is visible here.
            throw new QueryOptimizerException(e.getMessage());
        }
    }

    /**
     * Verifies that the given statement patterns form a single group connected through shared
     * variables.
     * <p>
     * Starting from the first pattern, the group is grown repeatedly with every pattern that
     * shares at least one binding name (names starting with {@code "_const_"} are not counted)
     * with the variables collected so far, until a full pass adds nothing.
     * Patterns are tracked in a set, so patterns that are {@code equals} to each other count
     * once when the sizes are compared at the end.
     *
     * @param list the statement patterns to check
     * @throws RuntimeException if not every pattern ended up in the connected group
     */
    private void checkDisjointPatternSets(List<StatementPattern> list) {
        Set<String> connectedVars = new LinkedHashSet<>();
        Set<StatementPattern> connectedPatterns = new LinkedHashSet<>();
        boolean grew;
        do {
            grew = false;
            for (StatementPattern pattern : list) {
                if (connectedPatterns.contains(pattern)) {
                    continue;
                }
                Set<String> patternVars = pattern.getBindingNames().stream()
                        .filter(name -> !name.startsWith("_const_"))
                        .collect(Collectors.toSet());
                // The very first pattern seeds the group; later ones must share a variable.
                if (connectedVars.isEmpty() || !Collections.disjoint(connectedVars, patternVars)) {
                    connectedVars.addAll(patternVars);
                    connectedPatterns.add(pattern);
                    grew = true;
                }
            }
        } while (grew);
        if (connectedPatterns.size() != list.size()) {
            throw new RuntimeException("disjoint set of patterns detecte! skip Union optimisation!");
        }
    }

    /**
     * Tells whether the union sits in a position the rewrite can handle.
     * <p>
     * Requires a tracked filter, a {@link Join} as the union's direct parent, and that every
     * node between that join and the tracked filter is a {@link Join} whose right argument is a
     * {@link StatementPattern}.
     *
     * @param union the union to inspect
     * @return {@code true} if walking up from the union's parent reaches the tracked filter
     *         through such joins only
     */
    private boolean isOptimizableUnion(Union union) {
        if (!this.isInsideFilter || this.filter == null) {
            return false;
        }
        QueryModelNode unionParent = union.getParentNode();
        if (!(unionParent instanceof Join)) {
            return false;
        }
        for (QueryModelNode ancestor = unionParent.getParentNode();
                ancestor != this.filter;
                ancestor = ancestor.getParentNode()) {
            // A null ancestor fails the instanceof test, so running off the tree yields false.
            boolean isJoinWithPatternOnRight = ancestor instanceof Join
                    && ((Join) ancestor).getRightArg() instanceof StatementPattern;
            if (!isJoinWithPatternOnRight) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether the union's parent is a join with a union on the right-hand side only.
     *
     * @param union the union to inspect
     * @return {@code true} if the parent is a {@link Join} whose right argument is a
     *         {@link Union} and whose left argument is not; {@code false} otherwise, including
     *         when the parent is not a join at all
     */
    private boolean isRegularUnionJoin(Union union) {
        // getParentNode() is typed as QueryModelNode, so check and cast explicitly.
        QueryModelNode parent = union.getParentNode();
        if (!(parent instanceof Join)) {
            return false;
        }
        Join parentJoin = (Join) parent;
        return (parentJoin.getRightArg() instanceof Union) && !(parentJoin.getLeftArg() instanceof Union);
    }

    /**
     * Collects the variable names used by the tracked filter's condition.
     *
     * @return the names reported by {@link VarNameCollector} for the condition
     */
    private Set<String> fetchFilterVars() {
        Set<String> conditionVarNames = VarNameCollector.process(this.filter.getCondition());
        return conditionVarNames;
    }

    /**
     * Collects, from below the tracked filter, every statement pattern that binds at least one
     * of the given variable names.
     * <p>
     * For each collected pattern the names of its non-constant subject, predicate and object
     * variables are added to {@code set2}.
     *
     * @param set  variable names to look for
     * @param set2 output set receiving the variable names of the collected patterns
     * @return the collected statement patterns
     * @throws RuntimeException if a matching pattern's parent node is not a {@link Join}
     */
    private List<StatementPattern> fetchMatchingStatementPatterns(final Set<String> set, final Set<String> set2) {
        StatementPatternCollector matchingCollector = new StatementPatternCollector() {
            @Override
            public void meet(StatementPattern pattern) {
                boolean bindsWantedVar = !Collections.disjoint(set, pattern.getBindingNames());
                if (!bindsWantedVar) {
                    return;
                }
                boolean parentIsJoin = pattern.getParentNode() instanceof Join;
                if (!parentIsJoin) {
                    throw new RuntimeException("Cannot create optimized sub-select for optional patterns");
                }
                // Record every non-constant position of the pattern.
                if (!pattern.getSubjectVar().isConstant()) {
                    set2.add(pattern.getSubjectVar().getName());
                }
                if (!pattern.getPredicateVar().isConstant()) {
                    set2.add(pattern.getPredicateVar().getName());
                }
                if (!pattern.getObjectVar().isConstant()) {
                    set2.add(pattern.getObjectVar().getName());
                }
                super.meet(pattern);
            }
        };
        this.filter.visitChildren(matchingCollector);
        return matchingCollector.getStatementPatterns();
    }

    /**
     * Builds the sub-select that evaluates the filter over copies of the given patterns.
     * <p>
     * The result is a {@link Projection} over the filter from {@link #generateNewFilter(List)},
     * projecting one element per name in {@code set}, flagged as a subquery with a variable
     * scope change.
     *
     * @param list the statement patterns the new filter is built over
     * @param set  the variable names to project
     * @return the newly created projection
     * @throws QueryOptimizerException if the new filter cannot be built
     */
    private Projection generateFilterProjection(List<StatementPattern> list, Set<String> set) throws QueryOptimizerException {
        Filter subSelectFilter = generateNewFilter(list);

        ProjectionElemList projectedVars = new ProjectionElemList();
        for (String varName : set) {
            ProjectionElem element = new ProjectionElem();
            element.setName(varName);
            projectedVars.addElement(element);
        }

        Projection subSelect = new Projection();
        subSelect.setArg(subSelectFilter);
        subSelect.setProjectionElemList(projectedVars);
        subSelect.setVariableScopeChange(true);
        subSelect.setSubquery(true);
        return subSelect;
    }

    /**
     * Creates a new filter that applies the tracked filter's condition to clones of the given
     * statement patterns.
     * <p>
     * One pattern is used directly; two are joined in list order; for more, the last two
     * patterns are joined first and the remaining ones are appended as right arguments while
     * walking towards the start of the list. The condition object itself is passed on, not a
     * copy of it.
     *
     * @param list the statement patterns to filter
     * @return the new filter
     * @throws QueryOptimizerException if {@code list} is empty
     */
    private Filter generateNewFilter(List<StatementPattern> list) throws QueryOptimizerException {
        int patternCount = list.size();
        if (patternCount == 0) {
            throw new QueryOptimizerException("Missing statement patterns for filter/union optimization");
        }
        if (patternCount == 1) {
            return new Filter(list.get(0).clone(), this.filter.getCondition());
        }
        if (patternCount == 2) {
            Join pair = new Join(list.get(0).clone(), list.get(1).clone());
            return new Filter(pair, this.filter.getCondition());
        }
        TupleExpr joined = new Join(list.get(patternCount - 1).clone(), list.get(patternCount - 2).clone());
        int index = patternCount - 3;
        while (index >= 0) {
            joined = new Join(joined, list.get(index).clone());
            index--;
        }
        return new Filter(joined, this.filter.getCondition());
    }

    /**
     * Replaces the union's parent join with the join produced by
     * {@link #generateUnionJoin(Join, Projection, List)}.
     *
     * @param union      the union whose parent (cast to {@link Join}) is replaced
     * @param projection the sub-select to put in front of the union
     * @param list       the statement patterns that were moved into the sub-select
     */
    private void replaceOriginalUnionParent(Union union, Projection projection, List<StatementPattern> list) {
        Join unionParent = (Join) union.getParentNode();
        Join replacement = generateUnionJoin(unionParent, projection, list);
        unionParent.replaceWith(replacement);
    }

    /**
     * Builds the join that takes the place of the union's original parent join.
     * <p>
     * The returned join always starts with {@code projection} on the left. Depending on what
     * the original join has on its left side, the result is either
     * {@code Join(projection, right)} or {@code Join(Join(projection, left), right)}.
     *
     * @param join       the union's original parent join
     * @param projection the sub-select holding the filter and its patterns
     * @param list       the statement patterns that were moved into the sub-select
     * @return the replacement join
     */
    private Join generateUnionJoin(Join join, Projection projection, List<StatementPattern> list) {
        Join join2 = new Join();
        join2.setLeftArg(projection);
        // Left side is a single pattern already covered by the sub-select: drop it entirely.
        if ((join.getLeftArg() instanceof StatementPattern) && list.contains(join.getLeftArg())) {
            join2.setRightArg(join.getRightArg());
            return join2;
        }
        if (join.getLeftArg() instanceof Join) {
            // Strips covered patterns out of the left sub-tree; the cast relies on the right
            // argument being a Union, which isRegularUnionJoin checks in meet(Union).
            removeFilterPatternsBeforeUnion(list, (Join) join.getLeftArg(), (Union) join.getRightArg());
            // The left argument is read again because the call above may have replaced it.
            if ((join.getLeftArg() instanceof Join) && shouldRemoveAllPatternsBeforeUnion((Join) join.getLeftArg(), list)) {
                join2.setRightArg(join.getRightArg());
                return join2;
            }
        }
        // Something remains on the left: keep it between the sub-select and the right argument.
        join2.setRightArg(join.getLeftArg());
        Join join3 = new Join();
        join3.setLeftArg(join2);
        join3.setRightArg(join.getRightArg());
        return join3;
    }

    /**
     * Tells whether the join consists of exactly two statement patterns that are both contained
     * in {@code list}.
     *
     * @param join the join to inspect
     * @param list the statement patterns that were moved into the sub-select
     * @return {@code true} if both arguments are statement patterns found in {@code list}
     */
    private boolean shouldRemoveAllPatternsBeforeUnion(Join join, List<StatementPattern> list) {
        TupleExpr left = join.getLeftArg();
        TupleExpr right = join.getRightArg();
        if (!(left instanceof StatementPattern) || !(right instanceof StatementPattern)) {
            return false;
        }
        return list.contains(left) && list.contains(right);
    }

    /**
     * Removes the tracked filter from the tree by replacing it with its own argument, then
     * strips the given patterns from the top of that argument.
     *
     * @param list the statement patterns that were moved into the sub-select
     */
    private void replaceOriginalFilter(List<StatementPattern> list) {
        // The argument is a Join whenever the filter was accepted by meet(Filter).
        Join filteredJoin = (Join) this.filter.getArg();
        this.filter.replaceWith(filteredJoin);
        removeFilterPatternsAfterUnion(list, filteredJoin);
    }

    /**
     * Removes, from the sub-tree below {@code join}, the statement patterns contained in
     * {@code list}.
     * <p>
     * The matching patterns are collected first and then processed from the last collected one
     * to the first. A pattern's parent join is replaced by its other argument; as soon as a
     * parent join with statement patterns on both sides is reached, handling is delegated to
     * {@link #handleLastStatements(Join, StatementPattern, Union, List)} and processing stops.
     *
     * @param list  the statement patterns that were moved into the sub-select
     * @param join  the root of the sub-tree to clean up
     * @param union the union on the right-hand side of the join being rewritten
     */
    private void removeFilterPatternsBeforeUnion(List<StatementPattern> list, Join join, Union union) {
        StatementPatternCollector matchingCollector = fetchPatternMatchingCollector(list);
        join.visitChildren(matchingCollector);
        List<StatementPattern> matchedPatterns = matchingCollector.getStatementPatterns();
        for (int index = matchedPatterns.size() - 1; index >= 0; index--) {
            StatementPattern pattern = matchedPatterns.get(index);
            // NOTE(review): assumes every matched pattern's parent is a Join — confirm.
            Join owner = (Join) pattern.getParentNode();
            if ((owner.getRightArg() instanceof StatementPattern) && (owner.getLeftArg() instanceof StatementPattern)) {
                handleLastStatements(owner, pattern, union, list);
                return;
            }
            if (owner.getRightArg() instanceof StatementPattern) {
                owner.replaceWith(owner.getLeftArg());
            } else {
                owner.replaceWith(owner.getRightArg());
            }
        }
    }

    /**
     * Handles a join whose two arguments are both statement patterns, one of them being
     * {@code statementPattern}.
     * <p>
     * If the other argument is contained in {@code list} as well, the whole join is handed to
     * {@link #handleLastFilterPatterns(Join, Union)}; otherwise the join is replaced by that
     * other argument.
     *
     * @param join             the join holding the two patterns
     * @param statementPattern the pattern currently being removed
     * @param union            the union on the right-hand side of the join being rewritten
     * @param list             the statement patterns that were moved into the sub-select
     */
    private void handleLastStatements(Join join, StatementPattern statementPattern, Union union, List<StatementPattern> list) {
        // The sibling is whichever argument is not the pattern being removed (identity check).
        TupleExpr sibling = join.getLeftArg() == statementPattern ? join.getRightArg() : join.getLeftArg();
        if (list.contains(sibling)) {
            handleLastFilterPatterns(join, union);
        } else {
            join.replaceWith(sibling);
        }
    }

    /**
     * Removes a join whose patterns are all covered by the sub-select, by replacing the join's
     * parent with the parent's other argument.
     * <p>
     * Nothing happens when the parent is not a {@link BinaryTupleOperator} or when the parent's
     * right argument is the given union.
     *
     * @param join  the join to remove
     * @param union the union on the right-hand side of the join being rewritten
     */
    private void handleLastFilterPatterns(Join join, Union union) {
        // getParentNode() is typed as QueryModelNode, so check and cast explicitly.
        QueryModelNode parent = join.getParentNode();
        if (!(parent instanceof BinaryTupleOperator)) {
            return;
        }
        BinaryTupleOperator parentOperator = (BinaryTupleOperator) parent;
        if (parentOperator.getRightArg() == union) {
            return;
        }
        if (parentOperator.getLeftArg() == join) {
            parentOperator.replaceWith(parentOperator.getRightArg());
        } else {
            parentOperator.replaceWith(parentOperator.getLeftArg());
        }
    }

    /**
     * Walks down the left spine of {@code join} and removes the right-hand statement patterns
     * that are contained in {@code list}.
     * <p>
     * The walk continues while the current join has a {@link Join} on the left and a
     * {@link StatementPattern} on the right. A join whose right pattern equals one of the listed
     * patterns is replaced by its left argument.
     *
     * @param list the statement patterns that were moved into the sub-select
     * @param join the join to start from
     */
    private void removeFilterPatternsAfterUnion(List<StatementPattern> list, Join join) {
        Join current = join;
        while (current.getLeftArg() instanceof Join && current.getRightArg() instanceof StatementPattern) {
            final Join inspected = current;
            boolean rightIsCovered = list.stream()
                    .anyMatch(candidate -> candidate.equals(inspected.getRightArg()));
            if (rightIsCovered) {
                inspected.replaceWith(inspected.getLeftArg());
            }
            // The replaced join still references its left argument, so the walk can go on.
            current = (Join) inspected.getLeftArg();
        }
    }

    /**
     * Creates a collector that keeps only the statement patterns contained in {@code list}.
     *
     * @param list the statement patterns to accept
     * @return a collector that passes a visited pattern on to the superclass only when
     *         {@code list} contains it
     */
    private StatementPatternCollector fetchPatternMatchingCollector(final List<StatementPattern> list) {
        return new StatementPatternCollector() {
            @Override
            public void meet(StatementPattern candidate) {
                if (!list.contains(candidate)) {
                    return;
                }
                super.meet(candidate);
            }
        };
    }
}
