Skip to content

Commit e37d026

Browse files
committed
Use logical null count in case_when_with_expr
1 parent 0bd127f commit e37d026

File tree

1 file changed

+73
-2
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+73
-2
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,12 +866,13 @@ impl CaseBody {
866866
// Since each when expression is tested against the base expression using the equality
867867
// operator, null base values can never match any when expression. `x = NULL` is falsy,
868868
// for all possible values of `x`.
869-
if base_values.null_count() > 0 {
869+
let base_null_count = base_values.logical_null_count();
870+
if base_null_count > 0 {
870871
// Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
871872
// We already checked there are nulls, so we can be sure a new buffer will not be
872873
// created.
873874
let base_not_nulls = is_not_null(base_values.as_ref())?;
874-
let base_all_null = base_values.null_count() == remainder_batch.num_rows();
875+
let base_all_null = base_null_count == remainder_batch.num_rows();
875876

876877
// If there is an else expression, use that as the default value for the null rows
877878
// Otherwise the default `null` value from the result builder will be used.
@@ -1545,6 +1546,76 @@ mod tests {
15451546
Ok(())
15461547
}
15471548

1549+
#[test]
1550+
fn case_with_expr_dictionary() -> Result<()> {
1551+
let schema = Schema::new(vec![Field::new("a", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), true)]);
1552+
let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1553+
let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1554+
let dictionary = DictionaryArray::new(keys, Arc::new(values));
1555+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1556+
1557+
let schema = batch.schema();
1558+
1559+
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1560+
let when1 = lit("foo");
1561+
let then1 = lit(123i32);
1562+
let when2 = lit("bar");
1563+
let then2 = lit(456i32);
1564+
1565+
let expr = generate_case_when_with_type_coercion(
1566+
Some(col("a", &schema)?),
1567+
vec![(when1, then1), (when2, then2)],
1568+
None,
1569+
schema.as_ref(),
1570+
)?;
1571+
let result = expr
1572+
.evaluate(&batch)?
1573+
.into_array(batch.num_rows())
1574+
.expect("Failed to convert to array");
1575+
let result = as_int32_array(&result)?;
1576+
1577+
let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1578+
1579+
assert_eq!(expected, result);
1580+
1581+
Ok(())
1582+
}
1583+
1584+
#[test]
1585+
fn case_with_expr_all_null_dictionary() -> Result<()> {
1586+
let schema = Schema::new(vec![Field::new("a", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), true)]);
1587+
let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1588+
let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1589+
let dictionary = DictionaryArray::new(keys, Arc::new(values));
1590+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1591+
1592+
let schema = batch.schema();
1593+
1594+
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1595+
let when1 = lit("foo");
1596+
let then1 = lit(123i32);
1597+
let when2 = lit("bar");
1598+
let then2 = lit(456i32);
1599+
1600+
let expr = generate_case_when_with_type_coercion(
1601+
Some(col("a", &schema)?),
1602+
vec![(when1, then1), (when2, then2)],
1603+
None,
1604+
schema.as_ref(),
1605+
)?;
1606+
let result = expr
1607+
.evaluate(&batch)?
1608+
.into_array(batch.num_rows())
1609+
.expect("Failed to convert to array");
1610+
let result = as_int32_array(&result)?;
1611+
1612+
let expected = &Int32Array::from(vec![None, None, None, None]);
1613+
1614+
assert_eq!(expected, result);
1615+
1616+
Ok(())
1617+
}
1618+
15481619
#[test]
15491620
fn case_with_expr_else() -> Result<()> {
15501621
let batch = case_test_batch()?;

0 commit comments

Comments
 (0)