diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3eb7eb6e6b2e8..754fea85ec6d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -249,6 +249,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { case _ => ExpressionSet(Seq.empty) } + private def isSameInteger(expr: Expression, value: Int): Boolean = expr match { + case l: Literal => l.value == value + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(BINARY_ARITHMETIC), ruleId) { case q: LogicalPlan => @@ -259,32 +264,31 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { val groupingExpressionSet = collectGroupingExpressions(q) q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) { case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable) - if (foldables.nonEmpty) { - val foldableExpr = foldables.reduce((x, y) => Add(x, y, f)) - val foldableValue = foldableExpr.eval(EmptyRow) + val (literals, others) = flattenAdd(a, groupingExpressionSet) + .partition(_.isInstanceOf[Literal]) + if (literals.nonEmpty) { + val literalExpr = literals.reduce((x, y) => Add(x, y, f)) if (others.isEmpty) { - Literal.create(foldableValue, a.dataType) - } else if (foldableValue == 0) { + literalExpr + } else if (isSameInteger(literalExpr, 0)) { others.reduce((x, y) => Add(x, y, f)) } else { - Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f) + Add(others.reduce((x, y) => Add(x, y, f)), literalExpr, f) } } else { a } case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] => - val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable) - if (foldables.nonEmpty) { - val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f)) - val foldableValue = foldableExpr.eval(EmptyRow) - if (others.isEmpty || (foldableValue == 0 && !m.nullable)) { - Literal.create(foldableValue, m.dataType) - } else if (foldableValue == 1) { + val (literals, others) = flattenMultiply(m, groupingExpressionSet) + .partition(_.isInstanceOf[Literal]) + if (literals.nonEmpty) { + val literalExpr = literals.reduce((x, y) => Multiply(x, y, f)) + if (others.isEmpty || (isSameInteger(literalExpr, 0) && !m.nullable)) { + literalExpr + } else if (isSameInteger(literalExpr, 1)) { others.reduce((x, y) => Multiply(x, y, f)) } else { - Multiply(others.reduce((x, y) => Multiply(x, y, f)), - Literal.create(foldableValue, m.dataType), f) + Multiply(others.reduce((x, y) => Multiply(x, y, f)), literalExpr, f) } } else { m diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index 9090e0c7fc104..7733e58547fe0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -29,7 +29,8 @@ class ReorderAssociativeOperatorSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("ReorderAssociativeOperator", Once, + Batch("ReorderAssociativeOperator", FixedPoint(10), + ConstantFolding, ReorderAssociativeOperator) :: Nil } @@ -44,7 +45,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest { ($"b" + 1) * 2 * 3 * 4, $"a" + 1 + $"b" + 2 + $"c" + 3, $"a" + 1 + $"b" * 2 + $"c" + 3, - Rand(0) * 1 * 2 * 3 * 4) + Rand(0) * 1.0 * 2.0 * 3.0 * 4.0) val optimized = Optimize.execute(originalQuery.analyze) @@ -56,7 +57,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest { (($"b" + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"), ($"a" + $"b" + $"c" + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), ($"a" + $"b" * 2 + $"c" + 4).as("((((a + 1) + (b * 2)) + c) + 3)"), - Rand(0) * 1 * 2 * 3 * 4) + Rand(0) * 1.0 * 2.0 * 3.0 * 4.0) .analyze comparePlans(optimized, correctAnswer) @@ -106,4 +107,17 @@ class ReorderAssociativeOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("SPARK-50380: conditional branches with error expression") { + val originalQuery1 = testRelation.select(If($"a" === 1, 1L, Literal(1).div(0) + $"b")).analyze + val optimized1 = Optimize.execute(originalQuery1) + comparePlans(optimized1, originalQuery1) + + val originalQuery2 = testRelation.select( + If($"a" === 1, 1, ($"b" + Literal(Int.MaxValue)) + 1).as("col")).analyze + val optimized2 = Optimize.execute(originalQuery2) + val correctAnswer2 = testRelation.select( + If($"a" === 1, 1, $"b" + (Literal(Int.MaxValue) + 1)).as("col")).analyze + comparePlans(optimized2, correctAnswer2) + } }