Skip to content

Commit 8946f8b

Browse files
wjones127 and alamb authored
feat: add guarantees to simplification (#7467)
* feat: add guarantees to simplification * null and comparison support * add support for literal expressions * implement inlist guarantee use * test the outer function * docs * refactor to use intervals * add high-level test * cleanup * fix test to be false or null, not true * refactor: change NullableInterval to an enum * refactor: use a builder-like API * pr feedback * Fix clippy * fix doc links --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 812864b commit 8946f8b

File tree

4 files changed

+1006
-3
lines changed

4 files changed

+1006
-3
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 174 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,20 @@ use datafusion_expr::{
3939
and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr,
4040
Like, Volatility,
4141
};
42-
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
42+
use datafusion_physical_expr::{
43+
create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval,
44+
};
4345

4446
use crate::simplify_expressions::SimplifyInfo;
4547

48+
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
49+
4650
/// This structure handles API for expression simplification
4751
pub struct ExprSimplifier<S> {
4852
info: S,
53+
/// Guarantees about the values of columns. This is provided by the user
54+
/// in [ExprSimplifier::with_guarantees()].
55+
guarantees: Vec<(Expr, NullableInterval)>,
4956
}
5057

5158
pub const THRESHOLD_INLINE_INLIST: usize = 3;
@@ -57,7 +64,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
5764
///
5865
/// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext
5966
pub fn new(info: S) -> Self {
60-
Self { info }
67+
Self {
68+
info,
69+
guarantees: vec![],
70+
}
6171
}
6272

6373
/// Simplifies this [`Expr`] as much as possible, evaluating
@@ -121,6 +131,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
121131
let mut simplifier = Simplifier::new(&self.info);
122132
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
123133
let mut or_in_list_simplifier = OrInListSimplifier::new();
134+
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
124135

125136
// TODO iterate until no changes are made during rewrite
126137
// (evaluating constants can enable new simplifications and
@@ -129,6 +140,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
129140
expr.rewrite(&mut const_evaluator)?
130141
.rewrite(&mut simplifier)?
131142
.rewrite(&mut or_in_list_simplifier)?
143+
.rewrite(&mut guarantee_rewriter)?
132144
// run both passes twice to try and minimize simplifications that we missed
133145
.rewrite(&mut const_evaluator)?
134146
.rewrite(&mut simplifier)
@@ -149,6 +161,65 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
149161

150162
expr.rewrite(&mut expr_rewrite)
151163
}
164+
165+
/// Input guarantees about the values of columns.
166+
///
167+
/// The guarantees can simplify expressions. For example, if a column `x` is
168+
/// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
169+
/// literal `true`.
170+
///
171+
/// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
172+
/// where the [Expr] is a column reference and the [NullableInterval]
173+
/// is an interval representing the known possible values of that column.
174+
///
175+
/// ```rust
176+
/// use arrow::datatypes::{DataType, Field, Schema};
177+
/// use datafusion_expr::{col, lit, Expr};
178+
/// use datafusion_common::{Result, ScalarValue, ToDFSchema};
179+
/// use datafusion_physical_expr::execution_props::ExecutionProps;
180+
/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
181+
/// use datafusion_optimizer::simplify_expressions::{
182+
/// ExprSimplifier, SimplifyContext};
183+
///
184+
/// let schema = Schema::new(vec![
185+
/// Field::new("x", DataType::Int64, false),
186+
/// Field::new("y", DataType::UInt32, false),
187+
/// Field::new("z", DataType::Int64, false),
188+
/// ])
189+
/// .to_dfschema_ref().unwrap();
190+
///
191+
/// // Create the simplifier
192+
/// let props = ExecutionProps::new();
193+
/// let context = SimplifyContext::new(&props)
194+
/// .with_schema(schema);
195+
///
196+
/// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
197+
/// let expr_x = col("x").gt_eq(lit(3_i64));
198+
/// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
199+
/// let expr_z = col("z").gt(lit(5_i64));
200+
/// let expr = expr_x.and(expr_y).and(expr_z.clone());
201+
///
202+
/// let guarantees = vec![
203+
/// // x ∈ [3, 5]
204+
/// (
205+
/// col("x"),
206+
/// NullableInterval::NotNull {
207+
/// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)),
208+
/// }
209+
/// ),
210+
/// // y = 3
211+
/// (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
212+
/// ];
213+
/// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
214+
/// let output = simplifier.simplify(expr).unwrap();
215+
/// // Expression becomes: true AND true AND (z > 5), which simplifies to
216+
/// // z > 5.
217+
/// assert_eq!(output, expr_z);
218+
/// ```
219+
pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
220+
self.guarantees = guarantees;
221+
self
222+
}
152223
}
153224

154225
#[allow(rustdoc::private_intra_doc_links)]
@@ -1239,7 +1310,9 @@ mod tests {
12391310
use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema};
12401311
use datafusion_expr::*;
12411312
use datafusion_physical_expr::{
1242-
execution_props::ExecutionProps, functions::make_scalar_function,
1313+
execution_props::ExecutionProps,
1314+
functions::make_scalar_function,
1315+
intervals::{Interval, NullableInterval},
12431316
};
12441317

12451318
// ------------------------------
@@ -2703,6 +2776,19 @@ mod tests {
27032776
try_simplify(expr).unwrap()
27042777
}
27052778

2779+
fn simplify_with_guarantee(
2780+
expr: Expr,
2781+
guarantees: Vec<(Expr, NullableInterval)>,
2782+
) -> Expr {
2783+
let schema = expr_test_schema();
2784+
let execution_props = ExecutionProps::new();
2785+
let simplifier = ExprSimplifier::new(
2786+
SimplifyContext::new(&execution_props).with_schema(schema),
2787+
)
2788+
.with_guarantees(guarantees);
2789+
simplifier.simplify(expr).unwrap()
2790+
}
2791+
27062792
fn expr_test_schema() -> DFSchemaRef {
27072793
Arc::new(
27082794
DFSchema::new_with_metadata(
@@ -3166,4 +3252,89 @@ mod tests {
31663252
let expr = not_ilike(null, "%");
31673253
assert_eq!(simplify(expr), lit_bool_null());
31683254
}
3255+
3256+
#[test]
3257+
fn test_simplify_with_guarantee() {
3258+
// (c3 > 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
3259+
let expr_x = col("c3").gt(lit(3_i64));
3260+
let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
3261+
let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
3262+
let expr = expr_x.clone().and(expr_y.clone().or(expr_z));
3263+
3264+
// All guaranteed null
3265+
let guarantees = vec![
3266+
(col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
3267+
(col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
3268+
(col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
3269+
];
3270+
3271+
let output = simplify_with_guarantee(expr.clone(), guarantees);
3272+
assert_eq!(output, lit_bool_null());
3273+
3274+
// All guaranteed false
3275+
let guarantees = vec![
3276+
(
3277+
col("c3"),
3278+
NullableInterval::NotNull {
3279+
values: Interval::make(Some(0_i64), Some(2_i64), (false, false)),
3280+
},
3281+
),
3282+
(
3283+
col("c4"),
3284+
NullableInterval::from(ScalarValue::UInt32(Some(9))),
3285+
),
3286+
(
3287+
col("c1"),
3288+
NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))),
3289+
),
3290+
];
3291+
let output = simplify_with_guarantee(expr.clone(), guarantees);
3292+
assert_eq!(output, lit(false));
3293+
3294+
// c3 guaranteed false or null -> the `c3 > 3` comparison is left unchanged,
// and it is all that remains of the expression.
3295+
let guarantees = vec![
3296+
(
3297+
col("c3"),
3298+
NullableInterval::MaybeNull {
3299+
values: Interval::make(Some(0_i64), Some(2_i64), (false, false)),
3300+
},
3301+
),
3302+
(
3303+
col("c4"),
3304+
NullableInterval::MaybeNull {
3305+
values: Interval::make(Some(9_u32), Some(9_u32), (false, false)),
3306+
},
3307+
),
3308+
(
3309+
col("c1"),
3310+
NullableInterval::NotNull {
3311+
values: Interval::make(Some("d"), Some("f"), (false, false)),
3312+
},
3313+
),
3314+
];
3315+
let output = simplify_with_guarantee(expr.clone(), guarantees);
3316+
assert_eq!(&output, &expr_x);
3317+
3318+
// Sufficient true guarantees
3319+
let guarantees = vec![
3320+
(
3321+
col("c3"),
3322+
NullableInterval::from(ScalarValue::Int64(Some(9))),
3323+
),
3324+
(
3325+
col("c4"),
3326+
NullableInterval::from(ScalarValue::UInt32(Some(3))),
3327+
),
3328+
];
3329+
let output = simplify_with_guarantee(expr.clone(), guarantees);
3330+
assert_eq!(output, lit(true));
3331+
3332+
// Only partially simplify
3333+
let guarantees = vec![(
3334+
col("c4"),
3335+
NullableInterval::from(ScalarValue::UInt32(Some(3))),
3336+
)];
3337+
let output = simplify_with_guarantee(expr.clone(), guarantees);
3338+
assert_eq!(&output, &expr_x);
3339+
}
31693340
}

0 commit comments

Comments
 (0)