@@ -866,12 +866,13 @@ impl CaseBody {
866866 // Since each when expression is tested against the base expression using the equality
867867 // operator, null base values can never match any when expression. `x = NULL` is falsy,
868868 // for all possible values of `x`.
869- if base_values. null_count ( ) > 0 {
869+ let base_null_count = base_values. logical_null_count ( ) ;
870+ if base_null_count > 0 {
870871 // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
871872 // We already checked there are nulls, so we can be sure a new buffer will not be
872873 // created.
873874 let base_not_nulls = is_not_null ( base_values. as_ref ( ) ) ?;
874- let base_all_null = base_values . null_count ( ) == remainder_batch. num_rows ( ) ;
875+ let base_all_null = base_null_count == remainder_batch. num_rows ( ) ;
875876
876877 // If there is an else expression, use that as the default value for the null rows
877878 // Otherwise the default `null` value from the result builder will be used.
@@ -1545,6 +1546,76 @@ mod tests {
15451546 Ok ( ( ) )
15461547 }
15471548
1549+ #[ test]
1550+ fn case_with_expr_dictionary ( ) -> Result < ( ) > {
1551+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Dictionary ( Box :: new( DataType :: UInt8 ) , Box :: new( DataType :: Utf8 ) ) , true ) ] ) ;
1552+ let keys = UInt8Array :: from ( vec ! [ 0u8 , 1u8 , 2u8 , 3u8 ] ) ;
1553+ let values = StringArray :: from ( vec ! [ Some ( "foo" ) , Some ( "baz" ) , None , Some ( "bar" ) ] ) ;
1554+ let dictionary = DictionaryArray :: new ( keys, Arc :: new ( values) ) ;
1555+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dictionary) ] ) ?;
1556+
1557+ let schema = batch. schema ( ) ;
1558+
1559+ // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1560+ let when1 = lit ( "foo" ) ;
1561+ let then1 = lit ( 123i32 ) ;
1562+ let when2 = lit ( "bar" ) ;
1563+ let then2 = lit ( 456i32 ) ;
1564+
1565+ let expr = generate_case_when_with_type_coercion (
1566+ Some ( col ( "a" , & schema) ?) ,
1567+ vec ! [ ( when1, then1) , ( when2, then2) ] ,
1568+ None ,
1569+ schema. as_ref ( ) ,
1570+ ) ?;
1571+ let result = expr
1572+ . evaluate ( & batch) ?
1573+ . into_array ( batch. num_rows ( ) )
1574+ . expect ( "Failed to convert to array" ) ;
1575+ let result = as_int32_array ( & result) ?;
1576+
1577+ let expected = & Int32Array :: from ( vec ! [ Some ( 123 ) , None , None , Some ( 456 ) ] ) ;
1578+
1579+ assert_eq ! ( expected, result) ;
1580+
1581+ Ok ( ( ) )
1582+ }
1583+
1584+ #[ test]
1585+ fn case_with_expr_all_null_dictionary ( ) -> Result < ( ) > {
1586+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Dictionary ( Box :: new( DataType :: UInt8 ) , Box :: new( DataType :: Utf8 ) ) , true ) ] ) ;
1587+ let keys = UInt8Array :: from ( vec ! [ 2u8 , 2u8 , 2u8 , 2u8 ] ) ;
1588+ let values = StringArray :: from ( vec ! [ Some ( "foo" ) , Some ( "baz" ) , None , Some ( "bar" ) ] ) ;
1589+ let dictionary = DictionaryArray :: new ( keys, Arc :: new ( values) ) ;
1590+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dictionary) ] ) ?;
1591+
1592+ let schema = batch. schema ( ) ;
1593+
1594+ // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1595+ let when1 = lit ( "foo" ) ;
1596+ let then1 = lit ( 123i32 ) ;
1597+ let when2 = lit ( "bar" ) ;
1598+ let then2 = lit ( 456i32 ) ;
1599+
1600+ let expr = generate_case_when_with_type_coercion (
1601+ Some ( col ( "a" , & schema) ?) ,
1602+ vec ! [ ( when1, then1) , ( when2, then2) ] ,
1603+ None ,
1604+ schema. as_ref ( ) ,
1605+ ) ?;
1606+ let result = expr
1607+ . evaluate ( & batch) ?
1608+ . into_array ( batch. num_rows ( ) )
1609+ . expect ( "Failed to convert to array" ) ;
1610+ let result = as_int32_array ( & result) ?;
1611+
1612+ let expected = & Int32Array :: from ( vec ! [ None , None , None , None ] ) ;
1613+
1614+ assert_eq ! ( expected, result) ;
1615+
1616+ Ok ( ( ) )
1617+ }
1618+
15481619 #[ test]
15491620 fn case_with_expr_else ( ) -> Result < ( ) > {
15501621 let batch = case_test_batch ( ) ?;
0 commit comments