Skip to content

Commit

Permalink
move is_volatile() check out of visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jun 21, 2024
1 parent be59824 commit fc71133
Showing 1 changed file with 41 additions and 33 deletions.
74 changes: 41 additions & 33 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,25 @@ impl CommonSubexprEliminate {
id_array: &mut IdArray<'n>,
expr_mask: ExprMask,
) -> Result<bool> {
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
Ok(if expr.is_volatile()? {
false
} else {
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;

Ok(visitor.found_common)
visitor.found_common
})
}

/// Rewrites `exprs_list` with common sub-expressions replaced with a new
Expand Down Expand Up @@ -950,11 +956,9 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
type Node = Expr;

fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
// TODO: propagate is_volatile state bottom-up + consider non-volatile sub-expressions for CSE
// TODO: consider non-volatile sub-expressions for CSE
// TODO: consider surely executed children of "short circuited"s for CSE
if expr.short_circuits() || expr.is_volatile()? {
if expr.short_circuits() {
self.visit_stack.push(VisitRecord::JumpMark);

return Ok(TreeNodeRecursion::Jump);
Expand Down Expand Up @@ -1013,14 +1017,6 @@ struct CommonSubexprRewriter<'a, 'n> {
impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter -= 1
}

Ok(Transformed::no(expr))
}

fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter += 1;
Expand All @@ -1029,7 +1025,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if expr.short_circuits() || expr.is_volatile()? {
if expr.short_circuits() {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

Expand Down Expand Up @@ -1069,6 +1065,14 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
Ok(Transformed::no(expr))
}
}

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
if matches!(expr, Expr::Alias(_)) {
self.alias_counter -= 1
}

Ok(Transformed::no(expr))
}
}

/// Replace common sub-expression in `expr` with the corresponding temporary
Expand All @@ -1080,14 +1084,18 @@ fn replace_common_expr<'n>(
common_exprs: &mut CommonExprs<'n>,
alias_generator: &AliasGenerator,
) -> Result<Transformed<Expr>> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_stats,
id_array,
common_exprs,
down_index: 0,
alias_counter: 0,
alias_generator,
})
if id_array.is_empty() {
Ok(Transformed::no(expr))
} else {
expr.rewrite(&mut CommonSubexprRewriter {
expr_stats,
id_array,
common_exprs,
down_index: 0,
alias_counter: 0,
alias_generator,
})
}
}

#[cfg(test)]
Expand Down

0 comments on commit fc71133

Please sign in to comment.