package org.apache.shardingsphere.sharding.route.engine.validator.dml.impl;

import java.util.List;
import java.util.Optional;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.model.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.route.engine.validator.dml.ShardingDMLStatementValidator;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;

/* loaded from: input_file:org/apache/shardingsphere/sharding/route/engine/validator/dml/impl/ShardingUpdateStatementValidator.class */
/**
 * Validator for UPDATE statements routed through sharding.
 *
 * <p>Before routing it rejects statements that assign a sharding column a value other than the one
 * the WHERE clause already pins that column to; after routing it rejects {@code UPDATE ... LIMIT}
 * statements that were routed to more than one route unit.
 */
public final class ShardingUpdateStatementValidator extends ShardingDMLStatementValidator<UpdateStatement> {

    /**
     * Checks every SET assignment of the statement against the sharding rule.
     *
     * <p>For each assigned column that {@code shardingRule} reports as a sharding column of the first
     * table in the tables context, the assigned value must be resolvable and must be equal to the value
     * extracted for that column from the WHERE clause.
     *
     * @param shardingRule rule used to decide whether an assigned column is a sharding column
     * @param sQLStatementContext context of the UPDATE statement being validated
     * @param list values bound to the statement's parameter markers, addressed by marker index
     * @param shardingSphereMetaData meta data; not read by this implementation
     * @throws ShardingSphereException if a sharding column is assigned and either value cannot be
     *         resolved or the two values differ
     */
    @Override // org.apache.shardingsphere.sharding.route.engine.validator.ShardingStatementValidator
    public void preValidate(ShardingRule shardingRule, SQLStatementContext<UpdateStatement> sQLStatementContext, List<Object> list, ShardingSphereMetaData shardingSphereMetaData) {
        validateShardingMultipleTable(shardingRule, sQLStatementContext);
        UpdateStatement sqlStatement = sQLStatementContext.getSqlStatement();
        // Only the first table of the statement is consulted for the logic table name.
        String tableName = ((SimpleTableSegment) sQLStatementContext.getTablesContext().getTables().iterator().next()).getTableName().getIdentifier().getValue();
        for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
            String columnName = each.getColumn().getIdentifier().getValue();
            if (!shardingRule.isShardingColumn(columnName, tableName)) {
                continue;
            }
            Optional<Object> assignedValue = getShardingColumnSetAssignmentValue(each, list);
            Optional<WhereSegment> where = sqlStatement.getWhere();
            Optional<Object> whereValue = where.isPresent() ? getShardingValue(where.get(), list, columnName) : Optional.<Object>empty();
            boolean unchanged = assignedValue.isPresent() && whereValue.isPresent() && assignedValue.get().equals(whereValue.get());
            if (!unchanged) {
                throw new ShardingSphereException("Can not update sharding key, logic table: [%s], column: [%s].", new Object[]{tableName, columnName});
            }
        }
    }

    /**
     * Resolves the value assigned by a SET clause entry.
     *
     * @param assignmentSegment the assignment whose right-hand side is resolved
     * @param parameters values bound to parameter markers
     * @return the literal or bound parameter value, or empty for any other kind of expression
     */
    private Optional<Object> getShardingColumnSetAssignmentValue(AssignmentSegment assignmentSegment, List<Object> parameters) {
        // A SET value and a comparison operand are resolved by exactly the same rules.
        return getPredicateCompareShardingValue(assignmentSegment.getValue(), parameters);
    }

    /**
     * Extracts the value the WHERE clause associates with the sharding column.
     *
     * @param whereSegment WHERE clause, may be {@code null}
     * @param parameters values bound to parameter markers
     * @param shardingColumn name of the sharding column to look for
     * @return the associated value, or empty if there is none
     */
    private Optional<Object> getShardingValue(WhereSegment whereSegment, List<Object> parameters, String shardingColumn) {
        if (null == whereSegment) {
            return Optional.empty();
        }
        return getShardingValue(whereSegment.getExpr(), parameters, shardingColumn);
    }

    /**
     * Walks a predicate expression looking for a condition on the sharding column.
     *
     * <p>Handled shapes: {@code column IN (...)}, {@code column <op> value} for the operators
     * {@code > >= = < <=}, and logical {@code AND / && / OR / ||} nodes, which are searched left
     * operand first. Every other expression yields empty.
     *
     * @param expression predicate expression to inspect
     * @param parameters values bound to parameter markers
     * @param shardingColumn name of the sharding column, compared case-insensitively
     * @return the first value found for the sharding column, or empty
     */
    private Optional<Object> getShardingValue(ExpressionSegment expression, List<Object> parameters, String shardingColumn) {
        if (expression instanceof InExpression) {
            InExpression inExpression = (InExpression) expression;
            // The IN list is only relevant when it constrains the sharding column itself,
            // mirroring the comparison branch below.
            if (isColumnNamed(inExpression.getLeft(), shardingColumn)) {
                return getPredicateInShardingValue(inExpression.getRight(), parameters);
            }
        }
        if (!(expression instanceof BinaryOperationExpression)) {
            return Optional.empty();
        }
        BinaryOperationExpression binary = (BinaryOperationExpression) expression;
        String operator = binary.getOperator();
        // NOTE(review): range operators are treated like "=" here; only the operand value is compared.
        if (isCompareOperator(operator) && isColumnNamed(binary.getLeft(), shardingColumn)) {
            return getPredicateCompareShardingValue(binary.getRight(), parameters);
        }
        if (!isLogicalOperator(operator)) {
            return Optional.empty();
        }
        Optional<Object> leftValue = getShardingValue(binary.getLeft(), parameters, shardingColumn);
        return leftValue.isPresent() ? leftValue : getShardingValue(binary.getRight(), parameters, shardingColumn);
    }

    /**
     * Resolves a single operand to a concrete value.
     *
     * @param expression operand expression
     * @param parameters values bound to parameter markers
     * @return the bound parameter for a parameter marker, the literal for a literal expression,
     *         otherwise empty
     */
    private Optional<Object> getPredicateCompareShardingValue(ExpressionSegment expression, List<Object> parameters) {
        if (expression instanceof ParameterMarkerExpressionSegment) {
            return getParameterValue(((ParameterMarkerExpressionSegment) expression).getParameterMarkerIndex(), parameters);
        }
        if (expression instanceof LiteralExpressionSegment) {
            return Optional.of(((LiteralExpressionSegment) expression).getLiterals());
        }
        return Optional.empty();
    }

    /**
     * Resolves the first usable value of an IN list.
     *
     * @param expression right-hand side of an IN expression
     * @param parameters values bound to parameter markers
     * @return the first resolvable parameter or literal in the list, or empty if {@code expression}
     *         is not a list or contains no resolvable item
     */
    private Optional<Object> getPredicateInShardingValue(ExpressionSegment expression, List<Object> parameters) {
        if (!(expression instanceof ListExpression)) {
            return Optional.empty();
        }
        for (ExpressionSegment each : ((ListExpression) expression).getItems()) {
            if (each instanceof ParameterMarkerExpressionSegment) {
                Optional<Object> parameterValue = getParameterValue(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex(), parameters);
                // An unresolvable marker is skipped so that a later item can still supply the value.
                if (parameterValue.isPresent()) {
                    return parameterValue;
                }
            } else if (each instanceof LiteralExpressionSegment) {
                return Optional.of(((LiteralExpressionSegment) each).getLiterals());
            }
        }
        return Optional.empty();
    }

    /**
     * Looks up a bound parameter by marker index.
     *
     * @param index parameter marker index
     * @param parameters values bound to parameter markers
     * @return the value at {@code index}, or empty when the index is out of range or the bound
     *         value is {@code null}
     */
    private Optional<Object> getParameterValue(int index, List<Object> parameters) {
        if (index < 0 || index >= parameters.size()) {
            return Optional.empty();
        }
        // ofNullable: a null bound value must end in the validation error, not a NullPointerException.
        return Optional.ofNullable(parameters.get(index));
    }

    /** Tells whether {@code expression} is a column reference whose name matches {@code columnName}, ignoring case. */
    private boolean isColumnNamed(ExpressionSegment expression, String columnName) {
        return expression instanceof ColumnSegment && columnName.equalsIgnoreCase(((ColumnSegment) expression).getIdentifier().getValue());
    }

    /** Tells whether {@code operator} is one of {@code > >= = < <=}. */
    private boolean isCompareOperator(String operator) {
        return ">".equalsIgnoreCase(operator) || ">=".equalsIgnoreCase(operator) || "=".equalsIgnoreCase(operator)
                || "<".equalsIgnoreCase(operator) || "<=".equalsIgnoreCase(operator);
    }

    /** Tells whether {@code operator} is one of {@code and && or ||}, ignoring case. */
    private boolean isLogicalOperator(String operator) {
        return "and".equalsIgnoreCase(operator) || "&&".equalsIgnoreCase(operator)
                || "OR".equalsIgnoreCase(operator) || "||".equalsIgnoreCase(operator);
    }

    /**
     * Rejects {@code UPDATE ... LIMIT} once routing produced more than one route unit.
     *
     * @param updateStatement the routed UPDATE statement
     * @param routeContext routing result whose route units are counted
     * @throws ShardingSphereException if the statement has a LIMIT segment and more than one route unit
     */
    @Override // org.apache.shardingsphere.sharding.route.engine.validator.ShardingStatementValidator
    public void postValidate(UpdateStatement updateStatement, RouteContext routeContext) {
        boolean hasLimit = UpdateStatementHandler.getLimitSegment(updateStatement).isPresent();
        if (hasLimit && routeContext.getRouteUnits().size() > 1) {
            throw new ShardingSphereException("UPDATE ... LIMIT can not support sharding route to multiple data nodes.", new Object[0]);
        }
    }
}
