Skip to content

Commit

Permalink
[SPARK-50380][SQL] ReorderAssociativeOperator should respect the contract in ConstantFolding
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

This PR fixes a long-standing issue in `ReorderAssociativeOperator`. In this rule, we flatten the Add/Multiply nodes, combine the foldable operands into a single Add/Multiply, and then evaluate it into a literal. This is normally fine, but #36468 added a new contract to `ConstantFolding`: with the introduction of ANSI mode, we don't want to fail eagerly for expressions within conditional branches. `ReorderAssociativeOperator` does not follow this contract.

The solution in this PR is to leave the expression evaluation to `ConstantFolding`. `ReorderAssociativeOperator` should only match literals. This makes sure that the early expression evaluation follows all the contracts in `ConstantFolding`.

### Why are the changes needed?

Avoid failing queries that should not fail. This also fixes a regression caused by #48395, which did not introduce the bug but made it more likely to happen.

### Does this PR introduce _any_ user-facing change?

Yes, failed queries can run now.

### How was this patch tested?

new test

### Was this patch authored or co-authored using generative AI tooling?

no

Closes #48918 from cloud-fan/error.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
cloud-fan authored and dongjoon-hyun committed Nov 21, 2024
1 parent 2d09ef2 commit d8a6075
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
case _ => ExpressionSet(Seq.empty)
}

/**
 * Checks whether `expr` is a [[Literal]] whose value equals `value`.
 *
 * @param expr  the expression to inspect
 * @param value the integer the literal's value is compared against
 * @return true only when `expr` is a `Literal` and its `value` compares equal to `value`;
 *         false for any other kind of expression
 */
private def isSameInteger(expr: Expression, value: Int): Boolean = {
  expr match {
    // The guard performs the comparison, so a non-matching literal falls through to `false`.
    case literal: Literal if literal.value == value => true
    case _ => false
  }
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(BINARY_ARITHMETIC), ruleId) {
case q: LogicalPlan =>
Expand All @@ -259,32 +264,31 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
val groupingExpressionSet = collectGroupingExpressions(q)
q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y, f))
val foldableValue = foldableExpr.eval(EmptyRow)
val (literals, others) = flattenAdd(a, groupingExpressionSet)
.partition(_.isInstanceOf[Literal])
if (literals.nonEmpty) {
val literalExpr = literals.reduce((x, y) => Add(x, y, f))
if (others.isEmpty) {
Literal.create(foldableValue, a.dataType)
} else if (foldableValue == 0) {
literalExpr
} else if (isSameInteger(literalExpr, 0)) {
others.reduce((x, y) => Add(x, y, f))
} else {
Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f)
Add(others.reduce((x, y) => Add(x, y, f)), literalExpr, f)
}
} else {
a
}
case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.nonEmpty) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f))
val foldableValue = foldableExpr.eval(EmptyRow)
if (others.isEmpty || (foldableValue == 0 && !m.nullable)) {
Literal.create(foldableValue, m.dataType)
} else if (foldableValue == 1) {
val (literals, others) = flattenMultiply(m, groupingExpressionSet)
.partition(_.isInstanceOf[Literal])
if (literals.nonEmpty) {
val literalExpr = literals.reduce((x, y) => Multiply(x, y, f))
if (others.isEmpty || (isSameInteger(literalExpr, 0) && !m.nullable)) {
literalExpr
} else if (isSameInteger(literalExpr, 1)) {
others.reduce((x, y) => Multiply(x, y, f))
} else {
Multiply(others.reduce((x, y) => Multiply(x, y, f)),
Literal.create(foldableValue, m.dataType), f)
Multiply(others.reduce((x, y) => Multiply(x, y, f)), literalExpr, f)
}
} else {
m
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ReorderAssociativeOperatorSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("ReorderAssociativeOperator", Once,
Batch("ReorderAssociativeOperator", FixedPoint(10),
ConstantFolding,
ReorderAssociativeOperator) :: Nil
}

Expand All @@ -44,7 +45,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
($"b" + 1) * 2 * 3 * 4,
$"a" + 1 + $"b" + 2 + $"c" + 3,
$"a" + 1 + $"b" * 2 + $"c" + 3,
Rand(0) * 1 * 2 * 3 * 4)
Rand(0) * 1.0 * 2.0 * 3.0 * 4.0)

val optimized = Optimize.execute(originalQuery.analyze)

Expand All @@ -56,7 +57,7 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
(($"b" + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
($"a" + $"b" + $"c" + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
($"a" + $"b" * 2 + $"c" + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
Rand(0) * 1 * 2 * 3 * 4)
Rand(0) * 1.0 * 2.0 * 3.0 * 4.0)
.analyze

comparePlans(optimized, correctAnswer)
Expand Down Expand Up @@ -106,4 +107,17 @@ class ReorderAssociativeOperatorSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("SPARK-50380: conditional branches with error expression") {
  // Case 1: the else-branch of the conditional adds a divide-by-zero expression to a column.
  // The optimized plan is expected to be identical to the analyzed input plan.
  val divideByZeroPlan =
    testRelation.select(If($"a" === 1, 1L, Literal(1).div(0) + $"b")).analyze
  comparePlans(Optimize.execute(divideByZeroPlan), divideByZeroPlan)

  // Case 2: the else-branch contains `Int.MaxValue + 1` spread across a nested Add.
  // The two literals are expected to be regrouped next to each other, but left as an
  // unevaluated Add rather than collapsed into a single literal.
  // NOTE(review): presumably evaluating this sum would raise an overflow error under
  // ANSI mode -- confirm against the ConstantFolding contract.
  val overflowPlan = testRelation.select(
    If($"a" === 1, 1, ($"b" + Literal(Int.MaxValue)) + 1).as("col")).analyze
  val expectedOverflowPlan = testRelation.select(
    If($"a" === 1, 1, $"b" + (Literal(Int.MaxValue) + 1)).as("col")).analyze
  comparePlans(Optimize.execute(overflowPlan), expectedOverflowPlan)
}
}

0 comments on commit d8a6075

Please sign in to comment.