From d8a6075dd61748c88733c4964ba37ed2430dc671 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 Nov 2024 14:30:19 -0800 Subject: [PATCH] [SPARK-50380][SQL] ReorderAssociativeOperator should respect the contract in ConstantFolding ### What changes were proposed in this pull request? This PR fixes a long-standing issue in `ReorderAssociativeOperator`. In this rule, we flatten the Add/Multiply nodes, and combine the foldable operands into a single Add/Multiply, then evaluate it into a literal. This is fine normally, but we added a new contract in `ConstantFolding` with https://github.com/apache/spark/pull/36468 , due to the introduction of ANSI mode and we don't want to fail eagerly for expressions within conditional branches. `ReorderAssociativeOperator` does not follow this contract. The solution in this PR is to leave the expression evaluation to `ConstantFolding`. `ReorderAssociativeOperator` should only match literals. This makes sure that the early expression evaluation follows all the contracts in `ConstantFolding`. ### Why are the changes needed? Avoid failing the query which should not fail. This also fixes a regression caused by https://github.com/apache/spark/pull/48395 , which does not introduce the bug, but makes the bug more likely to happen. ### Does this PR introduce _any_ user-facing change? Yes, failed queries can run now. ### How was this patch tested? new test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48918 from cloud-fan/error. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/optimizer/expressions.scala | 36 ++++++++++--------- .../ReorderAssociativeOperatorSuite.scala | 20 +++++++++-- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3eb7eb6e6b2e8..754fea85ec6d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -249,6 +249,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { case _ => ExpressionSet(Seq.empty) } + private def isSameInteger(expr: Expression, value: Int): Boolean = expr match { + case l: Literal => l.value == value + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(BINARY_ARITHMETIC), ruleId) { case q: LogicalPlan => @@ -259,32 +264,31 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { val groupingExpressionSet = collectGroupingExpressions(q) q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) { case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable) - if (foldables.nonEmpty) { - val foldableExpr = foldables.reduce((x, y) => Add(x, y, f)) - val foldableValue = foldableExpr.eval(EmptyRow) + val (literals, others) = flattenAdd(a, groupingExpressionSet) + .partition(_.isInstanceOf[Literal]) + if (literals.nonEmpty) { + val literalExpr = literals.reduce((x, y) => Add(x, y, f)) if (others.isEmpty) { - Literal.create(foldableValue, a.dataType) - } else if (foldableValue == 0) { + literalExpr + } else if (isSameInteger(literalExpr, 0)) { others.reduce((x, y) => Add(x, y, f)) } else { - Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f) + Add(others.reduce((x, y) => Add(x, y, f)), literalExpr, f) } } else { a } case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable) - if (foldables.nonEmpty) { - val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f)) - val foldableValue = foldableExpr.eval(EmptyRow) - if (others.isEmpty || (foldableValue == 0 && !m.nullable)) { - Literal.create(foldableValue, m.dataType) - } else if (foldableValue == 1) { + val (literals, others) = flattenMultiply(m, groupingExpressionSet) + .partition(_.isInstanceOf[Literal]) + if (literals.nonEmpty) { + val literalExpr = literals.reduce((x, y) => Multiply(x, y, f)) + if (others.isEmpty || (isSameInteger(literalExpr, 0) && !m.nullable)) { + literalExpr + } else if (isSameInteger(literalExpr, 1)) { others.reduce((x, y) => Multiply(x, y, f)) } else { - Multiply(others.reduce((x, y) => Multiply(x, y, f)), - Literal.create(foldableValue, m.dataType), f) + Multiply(others.reduce((x, y) => Multiply(x, y, f)), literalExpr, f) } } else { m diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index 9090e0c7fc104..7733e58547fe0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -29,7 +29,8 @@ class ReorderAssociativeOperatorSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("ReorderAssociativeOperator", Once, + Batch("ReorderAssociativeOperator", FixedPoint(10), + ConstantFolding, ReorderAssociativeOperator) :: Nil } @@ -44,7 +45,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest { ($"b" + 1) * 2 * 3 * 4, $"a" + 1 + $"b" + 2 + $"c" + 3, $"a" + 1 + $"b" * 2 + $"c" + 3, - Rand(0) * 1 * 2 * 3 * 4) + Rand(0) * 1.0 * 2.0 * 3.0 * 4.0) val optimized = Optimize.execute(originalQuery.analyze) @@ -56,7 +57,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest { (($"b" + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"), ($"a" + $"b" + $"c" + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), ($"a" + $"b" * 2 + $"c" + 4).as("((((a + 1) + (b * 2)) + c) + 3)"), - Rand(0) * 1 * 2 * 3 * 4) + Rand(0) * 1.0 * 2.0 * 3.0 * 4.0) .analyze comparePlans(optimized, correctAnswer) @@ -106,4 +107,17 @@ class ReorderAssociativeOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("SPARK-50380: conditional branches with error expression") { + val originalQuery1 = testRelation.select(If($"a" === 1, 1L, Literal(1).div(0) + $"b")).analyze + val optimized1 = Optimize.execute(originalQuery1) + comparePlans(optimized1, originalQuery1) + + val originalQuery2 = testRelation.select( + If($"a" === 1, 1, ($"b" + Literal(Int.MaxValue)) + 1).as("col")).analyze + val optimized2 = Optimize.execute(originalQuery2) + val correctAnswer2 = testRelation.select( + If($"a" === 1, 1, $"b" + (Literal(Int.MaxValue) + 1)).as("col")).analyze + comparePlans(optimized2, correctAnswer2) + } }