Skip to content

Commit 6c4256a

Browse files
committed
Refactor UnwrapCastInComparison to use rewrite()
1 parent a165b7f commit 6c4256a

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

datafusion/expr/src/expr_rewriter/mod.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,16 @@ pub fn unalias(expr: Expr) -> Expr {
276276
///
277277
/// This is important when optimizing plans to ensure the output
278278
/// schema of plan nodes don't change after optimization
279-
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
279+
pub fn rewrite_preserving_name<R>(
280+
expr: Expr,
281+
rewriter: &mut R,
282+
) -> Result<Transformed<Expr>>
280283
where
281284
R: TreeNodeRewriter<Node = Expr>,
282285
{
283286
let original_name = expr.name_for_alias()?;
284-
let expr = expr.rewrite(rewriter)?.data;
285-
expr.alias_if_changed(original_name)
287+
expr.rewrite(rewriter)?
288+
.map_data(|expr| expr.alias_if_changed(original_name))
286289
}
287290

288291
#[cfg(test)]
@@ -478,7 +481,9 @@ mod test {
478481
let mut rewriter = TestRewriter {
479482
rewrite_to: rewrite_to.clone(),
480483
};
481-
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
484+
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter)
485+
.data()
486+
.unwrap();
482487

483488
let original_name = match &expr_from {
484489
Expr::Sort(Sort { expr, .. }) => expr.display_name(),

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::sync::Arc;
2222
use arrow::datatypes::{DataType, IntervalUnit};
2323

2424
use datafusion_common::config::ConfigOptions;
25-
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
25+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRewriter};
2626
use datafusion_common::{
2727
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
2828
DataFusionError, Result, ScalarValue,
@@ -109,7 +109,7 @@ fn analyze_internal(
109109
.map(|expr| {
110110
// ensure aggregate names don't change:
111111
// https://github.com/apache/arrow-datafusion/issues/3555
112-
rewrite_preserving_name(expr, &mut expr_rewrite)
112+
rewrite_preserving_name(expr, &mut expr_rewrite).data()
113113
})
114114
.collect::<Result<Vec<_>>>()?;
115115

datafusion/optimizer/src/unwrap_cast_in_comparison.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ use std::sync::Arc;
2323
use crate::optimizer::ApplyOrder;
2424
use crate::{OptimizerConfig, OptimizerRule};
2525

26+
use crate::utils::NamePreserver;
2627
use arrow::datatypes::{
2728
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
2829
};
2930
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
30-
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
31+
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
3132
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue};
3233
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
3334
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
@@ -85,12 +86,32 @@ impl UnwrapCastInComparison {
8586
impl OptimizerRule for UnwrapCastInComparison {
8687
fn try_optimize(
8788
&self,
88-
plan: &LogicalPlan,
89+
_plan: &LogicalPlan,
8990
_config: &dyn OptimizerConfig,
9091
) -> Result<Option<LogicalPlan>> {
92+
internal_err!("Should have called UnwrapCastInComparison::rewrite")
93+
}
94+
95+
fn name(&self) -> &str {
96+
"unwrap_cast_in_comparison"
97+
}
98+
99+
fn apply_order(&self) -> Option<ApplyOrder> {
100+
Some(ApplyOrder::BottomUp)
101+
}
102+
103+
fn supports_rewrite(&self) -> bool {
104+
true
105+
}
106+
107+
fn rewrite(
108+
&self,
109+
plan: LogicalPlan,
110+
_config: &dyn OptimizerConfig,
111+
) -> Result<Transformed<LogicalPlan>> {
91112
let mut schema = merge_schema(plan.inputs());
92113

93-
if let LogicalPlan::TableScan(ts) = plan {
114+
if let LogicalPlan::TableScan(ts) = &plan {
94115
let source_schema = DFSchema::try_from_qualified_schema(
95116
ts.table_name.clone(),
96117
&ts.source.schema(),
@@ -104,22 +125,12 @@ impl OptimizerRule for UnwrapCastInComparison {
104125
schema: Arc::new(schema),
105126
};
106127

107-
let new_exprs = plan
108-
.expressions()
109-
.into_iter()
110-
.map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
111-
.collect::<Result<Vec<_>>>()?;
112-
113-
let inputs = plan.inputs().into_iter().cloned().collect();
114-
plan.with_new_exprs(new_exprs, inputs).map(Some)
115-
}
116-
117-
fn name(&self) -> &str {
118-
"unwrap_cast_in_comparison"
119-
}
120-
121-
fn apply_order(&self) -> Option<ApplyOrder> {
122-
Some(ApplyOrder::BottomUp)
128+
let name_preserver = NamePreserver::new(&plan);
129+
plan.map_expressions(|expr| {
130+
let original_name = name_preserver.save(&expr)?;
131+
expr.rewrite(&mut expr_rewriter)?
132+
.map_data(|expr| original_name.restore(expr))
133+
})
123134
}
124135
}
125136

0 commit comments

Comments
 (0)