diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index f77210371..ac4835a4e 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -109,6 +109,7 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" **Evaluated and not requiring separate Python exposure:** - `show_limit` — already covered by `DataFrame.show()`, which provides the same functionality with a simpler API - `with_param_values` — already covered by the `param_values` argument on `SessionContext.sql()`, which accomplishes the same thing more robustly +- `union_by_name_distinct` — already covered by `DataFrame.union_by_name(distinct=True)`, which provides a more Pythonic API **How to check:** 1. Fetch the upstream DataFrame documentation page listing all methods diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 72595ba81..fff5118d5 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -582,6 +582,14 @@ impl PyDataFrame { Ok(Self::new(df)) } + /// Apply window function expressions to the DataFrame + #[pyo3(signature = (*exprs))] + fn window(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> { + let window_exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().window(window_exprs)?; + Ok(Self::new(df)) + } + fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> { let df = self.df.as_ref().clone().filter(predicate.into())?; Ok(Self::new(df)) @@ -804,9 +812,27 @@ impl PyDataFrame { } /// Print the query plan - #[pyo3(signature = (verbose=false, analyze=false))] - fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> { - let df = self.df.as_ref().clone().explain(verbose, analyze)?; + #[pyo3(signature = (verbose=false, analyze=false, format=None))] + fn explain( + &self, + py: Python, + verbose: bool, + analyze: bool, + format: Option<&str>, + ) -> PyDataFusionResult<()> { + let explain_format = 
match format { + Some(f) => f + .parse::<datafusion::common::format::ExplainFormat>() + .map_err(|e| { + PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e)) + })?, + None => datafusion::common::format::ExplainFormat::Indent, + }; + let opts = datafusion::logical_expr::ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze) + .with_format(explain_format); + let df = self.df.as_ref().clone().explain_with_options(opts)?; print_dataframe(py, df) } @@ -864,22 +890,14 @@ impl PyDataFrame { Ok(Self::new(new_df)) } - /// Calculate the distinct union of two `DataFrame`s. The - /// two `DataFrame`s must have exactly the same schema - fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> { - let new_df = self - .df - .as_ref() - .clone() - .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) - } - - #[pyo3(signature = (column, preserve_nulls=true))] - fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + #[pyo3(signature = (column, preserve_nulls=true, recursions=None))] + fn unnest_column( + &self, + column: &str, + preserve_nulls: bool, + recursions: Option<Vec<(String, String, usize)>>, + ) -> PyDataFusionResult<Self> { + let unnest_options = build_unnest_options(preserve_nulls, recursions); let df = self .df .as_ref() @@ -888,15 +906,14 @@ impl PyDataFrame { Ok(Self::new(df)) } - #[pyo3(signature = (columns, preserve_nulls=true))] + #[pyo3(signature = (columns, preserve_nulls=true, recursions=None))] fn unnest_columns( &self, columns: Vec<String>, preserve_nulls: bool, + recursions: Option<Vec<(String, String, usize)>>, ) -> PyDataFusionResult<Self> { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + let unnest_options = build_unnest_options(preserve_nulls, recursions); let 
cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>(); let df = self .df @@ -907,21 +924,79 @@ impl PyDataFrame { } /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> { - let new_df = self - .df - .as_ref() - .clone() - .intersect(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn intersect(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.intersect_distinct(other)? + } else { + base.intersect(other)? + }; Ok(Self::new(new_df)) } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> { - let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn except_all(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.except_distinct(other)? + } else { + base.except(other)? + }; Ok(Self::new(new_df)) } + /// Union two DataFrames matching columns by name + #[pyo3(signature = (py_df, distinct=false))] + fn union_by_name(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.union_by_name_distinct(other)? + } else { + base.union_by_name(other)? 
+ }; + Ok(Self::new(new_df)) + } + + /// Deduplicate rows based on specific columns, keeping the first row per group + fn distinct_on( + &self, + on_expr: Vec<PyExpr>, + select_expr: Vec<PyExpr>, + sort_expr: Option<Vec<PySortExpr>>, + ) -> PyDataFusionResult<Self> { + let on_expr = on_expr.into_iter().map(|e| e.into()).collect(); + let select_expr = select_expr.into_iter().map(|e| e.into()).collect(); + let sort_expr = sort_expr.map(to_sort_expressions); + let df = self + .df + .as_ref() + .clone() + .distinct_on(on_expr, select_expr, sort_expr)?; + Ok(Self::new(df)) + } + + /// Sort by column expressions with ascending order and nulls last + fn sort_by(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> { + let exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().sort_by(exprs)?; + Ok(Self::new(df)) + } + + /// Return fully qualified column expressions for the given column names + fn find_qualified_columns(&self, names: Vec<String>) -> PyDataFusionResult<Vec<PyExpr>> { + let name_refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect(); + let qualified = self.df.find_qualified_columns(&name_refs)?; + Ok(qualified + .into_iter() + .map(|q| Expr::Column(Column::from(q)).into()) + .collect()) + } + /// Write a `DataFrame` to a CSV file. 
fn write_csv( &self, @@ -1295,6 +1370,26 @@ impl PyDataFrameWriteOptions { } } +fn build_unnest_options( + preserve_nulls: bool, + recursions: Option<Vec<(String, String, usize)>>, +) -> UnnestOptions { + let mut opts = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + if let Some(recs) = recursions { + opts.recursions = recs + .into_iter() + .map( + |(input, output, depth)| datafusion::common::RecursionUnnestOption { + input_column: datafusion::common::Column::from(input.as_str()), + output_column: datafusion::common::Column::from(output.as_str()), + depth, + }, + ) + .collect(); + } + opts +} + /// Print DataFrame fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> { // Get string representation of record batches diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst index 1d9d70385..a289c9377 100644 --- a/docs/source/user-guide/common-operations/joins.rst +++ b/docs/source/user-guide/common-operations/joins.rst @@ -134,3 +134,36 @@ In contrast to the above example, if we wish to get both columns: .. ipython:: python left.join(right, "id", how="inner", coalesce_duplicate_keys=False) + +Disambiguating Columns with ``DataFrame.col()`` +------------------------------------------------ + +When both DataFrames contain non-key columns with the same name, you can use +:py:meth:`~datafusion.dataframe.DataFrame.col` on each DataFrame **before** the +join to create fully qualified column references. These references can then be +used in the join predicate and when selecting from the result. + +This is especially useful with :py:meth:`~datafusion.dataframe.DataFrame.join_on`, +which accepts expression-based predicates. + +.. 
ipython:: python + + left = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [10, 20, 30], + } + ) + + right = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [40, 50, 60], + } + ) + + joined = left.join_on( + right, left.col("id") == right.col("id"), how="inner" + ) + + joined.select(left.col("id"), left.col("val"), right.col("val")) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 2e6f81166..a736c3966 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -47,6 +47,7 @@ from .dataframe import ( DataFrame, DataFrameWriteOptions, + ExplainFormat, InsertOp, ParquetColumnOptions, ParquetWriterOptions, @@ -82,6 +83,7 @@ "DataFrameWriteOptions", "Database", "ExecutionPlan", + "ExplainFormat", "Expr", "InsertOp", "LogicalPlan", diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 9907eae8b..9dc5f0e7d 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -44,6 +44,7 @@ Expr, SortExpr, SortKey, + _to_raw_expr, ensure_expr, ensure_expr_list, expr_list_to_raw_expr_list, @@ -65,6 +66,25 @@ from enum import Enum +class ExplainFormat(Enum): + """Output format for explain plans. + + Controls how the query plan is rendered in :py:meth:`DataFrame.explain`. + """ + + INDENT = "indent" + """Default indented text format.""" + + TREE = "tree" + """Tree-style visual format with box-drawing characters.""" + + PGJSON = "pgjson" + """PostgreSQL-compatible JSON format for use with visualization tools.""" + + GRAPHVIZ = "graphviz" + """Graphviz DOT format for graph rendering.""" + + # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class Compression(Enum): @@ -395,6 +415,80 @@ def schema(self) -> pa.Schema: """ return self.df.schema() + def column(self, name: str) -> Expr: + """Return a fully qualified column expression for ``name``. 
+ + Resolves an unqualified column name against this DataFrame's schema + and returns an :py:class:`Expr` whose underlying column reference + includes the table qualifier. This is especially useful after joins, + where the same column name may appear in multiple relations. + + Args: + name: Unqualified column name to look up. + + Returns: + A fully qualified column expression. + + Raises: + Exception: If the column is not found or is ambiguous (exists in + multiple relations). + + Examples: + Resolve a column from a simple DataFrame: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + >>> expr = df.column("a") + >>> df.select(expr).to_pydict() + {'a': [1, 2]} + + Resolve qualified columns after a join: + + >>> left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + >>> joined = left.join(right, on="id", how="inner") + >>> expr = joined.column("y") + >>> joined.select("id", expr).sort("id").to_pydict() + {'id': [1, 2], 'y': [30, 40]} + """ + return self.find_qualified_columns(name)[0] + + def col(self, name: str) -> Expr: + """Alias for :py:meth:`column`. + + See Also: + :py:meth:`column` + """ + return self.column(name) + + def find_qualified_columns(self, *names: str) -> list[Expr]: + """Return fully qualified column expressions for the given names. + + This is a batch version of :py:meth:`column` — it resolves each + unqualified name against the DataFrame's schema and returns a list + of qualified column expressions. + + Args: + names: Unqualified column names to look up. + + Returns: + List of fully qualified column expressions, one per name. + + Raises: + Exception: If any column is not found or is ambiguous. 
+ + Examples: + Resolve multiple columns at once: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + >>> exprs = df.find_qualified_columns("a", "c") + >>> df.select(*exprs).to_pydict() + {'a': [1, 2], 'c': [5, 6]} + """ + raw_exprs = self.df.find_qualified_columns(list(names)) + return [Expr(e) for e in raw_exprs] + @deprecated( "select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead" ) @@ -468,6 +562,36 @@ def drop(self, *columns: str) -> DataFrame: """ return DataFrame(self.df.drop(*columns)) + def window(self, *exprs: Expr) -> DataFrame: + """Add window function columns to the DataFrame. + + Applies the given window function expressions and appends the results + as new columns. + + Args: + exprs: Window function expressions to evaluate. + + Returns: + DataFrame with new window function columns appended. + + Examples: + Add a row number within each group: + + >>> import datafusion.functions as f + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + >>> df = df.window( + ... f.row_number( + ... partition_by=[col("b")], order_by=[col("a")] + ... ).alias("rn") + ... ) + >>> "rn" in df.schema().names + True + """ + raw = expr_list_to_raw_expr_list(exprs) + return DataFrame(self.df.window(*raw)) + def filter(self, *predicates: Expr | str) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. @@ -837,7 +961,13 @@ def join( ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. - `on` has to be provided or both `left_on` and `right_on` in conjunction. + ``on`` has to be provided or both ``left_on`` and ``right_on`` in + conjunction. + + When non-key columns share the same name in both DataFrames, use + :py:meth:`DataFrame.col` on each DataFrame **before** the join to + obtain fully qualified column references that can disambiguate them. 
+ See :py:meth:`join_on` for an example. Args: right: Other DataFrame to join with. @@ -911,7 +1041,14 @@ def join_on( built with :func:`datafusion.col`. On expressions are used to support in-equality predicates. Equality predicates are correctly optimized. + Use :py:meth:`DataFrame.col` on each DataFrame **before** the join to + obtain fully qualified column references. These qualified references + can then be used in the join predicate and to disambiguate columns + with the same name when selecting from the result. + Examples: + Join with unique column names: + >>> ctx = dfn.SessionContext() >>> left = ctx.from_pydict({"a": [1, 2], "x": ["a", "b"]}) >>> right = ctx.from_pydict({"b": [1, 2], "y": ["c", "d"]}) @@ -920,6 +1057,18 @@ def join_on( ... ).sort(col("x")).to_pydict() {'a': [1, 2], 'x': ['a', 'b'], 'b': [1, 2], 'y': ['c', 'd']} + Use :py:meth:`col` to disambiguate shared column names: + + >>> left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + >>> joined = left.join_on( + ... right, left.col("id") == right.col("id"), how="inner" + ... ) + >>> joined.select( + ... left.col("id"), left.col("val"), right.col("val").alias("rval") + ... ).sort(left.col("id")).to_pydict() + {'id': [1, 2], 'val': [10, 20], 'rval': [30, 40]} + Args: right: Other DataFrame to join with. on_exprs: single or multiple (in)-equality predicates. @@ -932,7 +1081,12 @@ def join_on( exprs = [ensure_expr(expr) for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) - def explain(self, verbose: bool = False, analyze: bool = False) -> None: + def explain( + self, + verbose: bool = False, + analyze: bool = False, + format: ExplainFormat | None = None, + ) -> None: """Print an explanation of the DataFrame's plan so far. If ``analyze`` is specified, runs the plan and reports metrics. 
@@ -940,8 +1094,23 @@ def explain(self, verbose: bool = False, analyze: bool = False) -> None: Args: verbose: If ``True``, more details will be included. analyze: If ``True``, the plan will run and metrics reported. + format: Output format for the plan. Defaults to + :py:attr:`ExplainFormat.INDENT`. + + Examples: + Show the plan in tree format: + + >>> from datafusion import ExplainFormat + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.explain(format=ExplainFormat.TREE) # doctest: +SKIP + + Show plan with runtime metrics: + + >>> df.explain(analyze=True) # doctest: +SKIP """ - self.df.explain(verbose, analyze) + fmt = format.value if format is not None else None + self.df.explain(verbose, analyze, fmt) def logical_plan(self) -> LogicalPlan: """Return the unoptimized ``LogicalPlan``. @@ -1010,45 +1179,170 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: """ return DataFrame(self.df.union(other.df, distinct)) + @deprecated( + "union_distinct() is deprecated. Use union(other, distinct=True) instead." + ) def union_distinct(self, other: DataFrame) -> DataFrame: """Calculate the distinct union of two :py:class:`DataFrame`. + See Also: + :py:meth:`union` + """ + return self.union(other, distinct=True) + + def intersect(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the intersection of two :py:class:`DataFrame`. + The two :py:class:`DataFrame` must have exactly the same schema. - Any duplicate rows are discarded. Args: - other: DataFrame to union with. + other: DataFrame to intersect with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after union. + DataFrame after intersection. 
+ + Examples: + Find rows common to both DataFrames: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]}) + >>> df1.intersect(df2).to_pydict() + {'a': [1], 'b': [10]} + + Intersect with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 20]}) + >>> df2 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df1.intersect(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} """ - return DataFrame(self.df.union_distinct(other.df)) + return DataFrame(self.df.intersect(other.df, distinct)) - def intersect(self, other: DataFrame) -> DataFrame: - """Calculate the intersection of two :py:class:`DataFrame`. + def except_all(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the set difference of two :py:class:`DataFrame`. + + Returns rows that are in this DataFrame but not in ``other``. The two :py:class:`DataFrame` must have exactly the same schema. Args: - other: DataFrame to intersect with. + other: DataFrame to calculate exception with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after intersection. + DataFrame after set difference. + + Examples: + Remove rows present in ``df2``: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 2], "b": [10, 20]}) + >>> df1.except_all(df2).sort("a").to_pydict() + {'a': [3], 'b': [30]} + + Remove rows present in ``df2`` and deduplicate: + + >>> df1.except_all(df2, distinct=True).sort("a").to_pydict() + {'a': [3], 'b': [30]} """ - return DataFrame(self.df.intersect(other.df)) + return DataFrame(self.df.except_all(other.df, distinct)) - def except_all(self, other: DataFrame) -> DataFrame: - """Calculate the exception of two :py:class:`DataFrame`. 
+ def union_by_name(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Union two :py:class:`DataFrame` matching columns by name. - The two :py:class:`DataFrame` must have exactly the same schema. + Unlike :py:meth:`union` which matches columns by position, this method + matches columns by their names, allowing DataFrames with different + column orders to be combined. Args: - other: DataFrame to calculate exception with. + other: DataFrame to union with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after exception. + DataFrame after union by name. + + Examples: + Combine DataFrames with different column orders: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1], "b": [10]}) + >>> df2 = ctx.from_pydict({"b": [20], "a": [2]}) + >>> df1.union_by_name(df2).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 20]} + + Union by name with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df2 = ctx.from_pydict({"b": [10], "a": [1]}) + >>> df1.union_by_name(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} """ - return DataFrame(self.df.except_all(other.df)) + return DataFrame(self.df.union_by_name(other.df, distinct)) + + def distinct_on( + self, + on_expr: list[Expr], + select_expr: list[Expr], + sort_expr: list[SortKey] | None = None, + ) -> DataFrame: + """Deduplicate rows based on specific columns. + + Returns a new DataFrame with one row per unique combination of the + ``on_expr`` columns, keeping the first row per group as determined by + ``sort_expr``. + + Args: + on_expr: Expressions that determine uniqueness. + select_expr: Expressions to include in the output. + sort_expr: Optional sort expressions to determine which row to keep. + + Returns: + DataFrame after deduplication. 
+ + Examples: + Keep the row with the smallest ``b`` for each unique ``a``: + + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + >>> df.distinct_on( + ... [col("a")], + ... [col("a"), col("b")], + ... [col("a").sort(ascending=True), col("b").sort(ascending=True)], + ... ).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 30]} + """ + on_raw = expr_list_to_raw_expr_list(on_expr) + select_raw = expr_list_to_raw_expr_list(select_expr) + sort_raw = sort_list_to_raw_sort_list(sort_expr) if sort_expr else None + return DataFrame(self.df.distinct_on(on_raw, select_raw, sort_raw)) + + def sort_by(self, *exprs: Expr | str) -> DataFrame: + """Sort the DataFrame by column expressions in ascending order. + + This is a convenience method that sorts the DataFrame by the given + expressions in ascending order with nulls last. For more control over + sort direction and null ordering, use :py:meth:`sort` instead. + + Args: + exprs: Expressions or column names to sort by. + + Returns: + DataFrame after sorting. 
+ + Examples: + Sort by a single column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2]}) + >>> df.sort_by("a").to_pydict() + {'a': [1, 2, 3]} + """ + raw = [_to_raw_expr(e) for e in exprs] + return DataFrame(self.df.sort_by(raw)) def write_csv( self, @@ -1310,23 +1604,52 @@ def count(self) -> int: return self.df.count() @deprecated("Use :py:func:`unnest_columns` instead.") - def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: + def unnest_column( + self, + column: str, + preserve_nulls: bool = True, + ) -> DataFrame: """See :py:func:`unnest_columns`.""" return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) - def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: + def unnest_columns( + self, + *columns: str, + preserve_nulls: bool = True, + recursions: list[tuple[str, str, int]] | None = None, + ) -> DataFrame: """Expand columns of arrays into a single row per array element. Args: columns: Column names to perform unnest operation on. preserve_nulls: If False, rows with null entries will not be returned. + recursions: Optional list of ``(input_column, output_column, depth)`` + tuples that control how deeply nested columns are unnested. Any + column not mentioned here is unnested with depth 1. Returns: A DataFrame with the columns expanded. 
+ + Examples: + Unnest an array column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2], [3]], "b": ["x", "y"]}) + >>> df.unnest_columns("a").to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} + + With explicit recursion depth: + + >>> df.unnest_columns("a", recursions=[("a", "a", 1)]).to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} """ columns = list(columns) - return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) + return DataFrame( + self.df.unnest_columns( + columns, preserve_nulls=preserve_nulls, recursions=recursions + ) + ) def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """Export the DataFrame as an Arrow C Stream. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 759d6278c..bb8e9685c 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,6 +29,7 @@ import pytest from datafusion import ( DataFrame, + ExplainFormat, InsertOp, ParquetColumnOptions, ParquetWriterOptions, @@ -3569,3 +3570,263 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order): pa.parquet.write_table(table, path) df = ctx.read_parquet(path, file_sort_order=file_sort_order) assert df.collect()[0].column(0).to_pylist() == [1, 2] + + +@pytest.mark.parametrize( + ("df1_data", "df2_data", "method", "kwargs", "expected_a", "expected_b"), + [ + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 2], "b": [10, 20]}, + "except_all", + {"distinct": True}, + [3], + [30], + id="except_all(distinct=True): removes matching rows and deduplicates", + ), + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 4], "b": [10, 40]}, + "intersect", + {"distinct": True}, + [1], + [10], + id="intersect(distinct=True): keeps common rows and deduplicates", + ), + pytest.param( + {"a": [1], "b": [10]}, + {"b": [20], "a": [2]}, # reversed column order tests matching by name + "union_by_name", + {}, + [1, 2], + 
[10, 20], + id="union_by_name: matches columns by name not position", + ), + ], +) +def test_set_operations_distinct( + df1_data, df2_data, method, kwargs, expected_a, expected_b +): + ctx = SessionContext() + df1 = ctx.from_pydict(df1_data) + df2 = ctx.from_pydict(df2_data) + result = ( + getattr(df1, method)(df2, **kwargs) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + assert result.column(0).to_pylist() == expected_a + assert result.column(1).to_pylist() == expected_b + + +def test_union_by_name_distinct(): + ctx = SessionContext() + df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + df2 = ctx.from_pydict({"b": [10], "a": [1]}) + result = df1.union_by_name(df2, distinct=True).collect()[0] + assert result.column(0).to_pylist() == [1] + assert result.column(1).to_pylist() == [10] + + +def test_column_qualified(): + """DataFrame.column() returns a qualified column expression.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + expr = df.column("a") + result = df.select(expr).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + + +def test_column_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.column("z") + + +def test_column_ambiguous(): + """After a join, duplicate column names that cannot be resolved raise an error.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + joined = left.join(right, on="id", how="inner") + with pytest.raises(Exception, match="not found"): + joined.column("val") + + +def test_column_after_join(): + """Qualified column works for non-ambiguous columns after a join.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + joined = left.join(right, on="id", how="inner") + expr = joined.column("y") + result = 
joined.select("id", expr).sort("id").collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [30, 40] + + +def test_col_join_disambiguate(): + """Use col() to disambiguate and select columns after a join.""" + ctx = SessionContext() + df1 = ctx.from_pydict({"foo": [1, 2, 3], "bar": [5, 6, 7]}) + df2 = ctx.from_pydict({"foo": [1, 2, 3], "baz": [8, 9, 10]}) + joined = df1.join_on(df2, df1.col("foo") == df2.col("foo"), how="inner") + result = ( + joined.select(df1.col("foo"), df1.col("bar"), df2.col("baz")) + .sort(df1.col("foo")) + .to_pydict() + ) + assert result["bar"] == [5, 6, 7] + assert result["baz"] == [8, 9, 10] + + +def test_find_qualified_columns(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + exprs = df.find_qualified_columns("a", "c") + assert len(exprs) == 2 + result = df.select(*exprs).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [5, 6] + + +def test_find_qualified_columns_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.find_qualified_columns("a", "z") + + +def test_distinct_on(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + result = ( + df.distinct_on( + [column("a")], + [column("a"), column("b")], + [column("a").sort(ascending=True), column("b").sort(ascending=True)], + ) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + # Keeps the first row per group (smallest b per a) + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [10, 30] + + +@pytest.mark.parametrize( + ("input_values", "expected"), + [ + ([3, 1, 2], [1, 2, 3]), + ([1, 2, 3], [1, 2, 3]), + ([3, None, 1, 2], [1, 2, 3, None]), + ], +) +def test_sort_by(input_values, expected): + """sort_by always sorts ascending with nulls last regardless of input order.""" + ctx = 
SessionContext() + df = ctx.from_pydict({"a": input_values}) + result = df.sort_by(column("a")).collect()[0] + assert result.column(0).to_pylist() == expected + + +@pytest.mark.parametrize( + ("fmt", "verbose", "analyze", "expected_substring"), + [ + pytest.param(None, False, False, None, id="default format"), + pytest.param(ExplainFormat.TREE, False, False, "---", id="tree format"), + pytest.param( + ExplainFormat.INDENT, True, True, None, id="indent verbose+analyze" + ), + pytest.param(ExplainFormat.PGJSON, False, False, '"Plan"', id="pgjson format"), + pytest.param( + ExplainFormat.GRAPHVIZ, False, False, "digraph", id="graphviz format" + ), + ], +) +def test_explain_with_format(capsys, fmt, verbose, analyze, expected_substring): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + df.explain(verbose=verbose, analyze=analyze, format=fmt) + captured = capsys.readouterr() + assert "plan_type" in captured.out + if expected_substring is not None: + assert expected_substring in captured.out + + +@pytest.mark.parametrize( + ("window_exprs", "expected_columns"), + [ + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + ], + {"rn": [1, 2, 1]}, + id="single window expression", + ), + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + f.rank(partition_by=[column("b")], order_by=[column("a")]).alias("rnk"), + ], + {"rn": [1, 2, 1], "rnk": [1, 2, 1]}, + id="multiple window expressions", + ), + ], +) +def test_window(window_exprs, expected_columns): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + result = ( + df.window(*window_exprs()).sort(column("a").sort(ascending=True)).collect()[0] + ) + for col_name, expected_values in expected_columns.items(): + assert col_name in result.schema.names + assert ( + result.column(result.schema.get_field_index(col_name)).to_pylist() + == expected_values + ) + + 
+@pytest.mark.parametrize( + ("input_data", "recursions", "expected_a"), + [ + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + None, + [1, 2, 3], + id="basic unnest without recursions", + ), + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + [("a", "a", 1)], + [1, 2, 3], + id="explicit depth 1 matches basic unnest", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 1)], + [[1, 2], [3], [4]], + id="depth 1 on nested lists keeps inner lists", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 2)], + [1, 2, 3, 4], + id="depth 2 fully flattens nested lists", + ), + ], +) +def test_unnest_columns_with_recursions(input_data, recursions, expected_a): + ctx = SessionContext() + df = ctx.from_pydict(input_data) + kwargs = {} + if recursions is not None: + kwargs["recursions"] = recursions + result = df.unnest_columns("a", **kwargs).collect()[0] + assert result.column(0).to_pylist() == expected_a