diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py
index 71ff8c59..7e2b8d0d 100644
--- a/duckdb/experimental/spark/sql/functions.py
+++ b/duckdb/experimental/spark/sql/functions.py
@@ -1,5 +1,7 @@
+import operator
 import warnings
 from collections.abc import Callable
+from functools import reduce
 from typing import TYPE_CHECKING, Any, Optional, Union, overload
 
 from duckdb import (
@@ -6208,6 +6210,164 @@ def expr(str: str) -> Column:
     return Column(SQLExpression(str))
 
 
+def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
+    """Aggregate function: returns the number of distinct rows considering the given columns.
+
+    Rows where any of the supplied columns is NULL are excluded from the count,
+    matching Spark / standard SQL `COUNT(DISTINCT col1, col2, ...)` semantics.
+
+    .. versionadded:: 1.3.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1,), (1,), (2,), (None,)], ["v"])
+    >>> df.select(count_distinct(df.v).alias("d")).collect()
+    [Row(d=2)]
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, "a"), (1, "a"), (1, "b"), (None, "c"), (2, None)], ["a", "b"]
+    ... )
+    >>> df.select(count_distinct("a", "b").alias("d")).collect()
+    [Row(d=2)]
+    """
+    exprs = [_to_column_expr(c) for c in (col, *cols)]
+    if len(exprs) == 1:
+        arg = exprs[0]
+    else:
+        any_null = reduce(operator.or_, (e.isnull() for e in exprs))
+        arg = CaseExpression(any_null, ConstantExpression(None)).otherwise(FunctionExpression("struct_pack", *exprs))
+    return _invoke_function(
+        "array_length",
+        FunctionExpression(
+            "array_distinct",
+            FunctionExpression("list", arg),
+        ),
+    )
+
+
+def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
+    """Alias of :func:`count_distinct`."""
+    return count_distinct(col, *cols)
+
+
+def collect_set(col: "ColumnOrName") -> Column:
+    """Aggregate function: returns a set of objects with duplicate elements eliminated.
+
+    NULL values are excluded. The order of elements is non-deterministic.
+
+    .. versionadded:: 1.6.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1,), (1,), (2,)], ["v"])
+    >>> sorted(df.select(collect_set("v")).first()[0])
+    [1, 2]
+    """
+    return _invoke_function(
+        "array_distinct",
+        FunctionExpression("list", _to_column_expr(col)),
+    )
+
+
+def count_if(col: "ColumnOrName") -> Column:
+    """Aggregate function: returns the number of `TRUE` values for the expression.
+
+    .. versionadded:: 3.5.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
+    >>> df.select(count_if(df.v > 1).alias("c")).collect()
+    [Row(c=2)]
+    """
+    return _invoke_function_over_columns("count_if", col)
+
+
+def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+    """Returns the value associated with the maximum value of `ord`.
+
+    .. versionadded:: 3.3.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
+    >>> df.select(max_by("k", "v")).first()[0]
+    'b'
+    """
+    return _invoke_function("arg_max", _to_column_expr(col), _to_column_expr(ord))
+
+
+def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
+    """Returns the value associated with the minimum value of `ord`.
+
+    .. versionadded:: 3.3.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
+    >>> df.select(min_by("k", "v")).first()[0]
+    'a'
+    """
+    return _invoke_function("arg_min", _to_column_expr(col), _to_column_expr(ord))
+
+
+def bool_and(col: "ColumnOrName") -> Column:
+    """Aggregate function: returns true if all values of `col` are true.
+
+    .. versionadded:: 3.5.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
+    >>> df.select(bool_and("b")).first()[0]
+    False
+    """
+    return _invoke_function_over_columns("bool_and", col)
+
+
+def every(col: "ColumnOrName") -> Column:
+    """Alias of :func:`bool_and`."""
+    return bool_and(col)
+
+
+def bool_or(col: "ColumnOrName") -> Column:
+    """Aggregate function: returns true if at least one value of `col` is true.
+
+    .. versionadded:: 3.5.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
+    >>> df.select(bool_or("b")).first()[0]
+    True
+    """
+    return _invoke_function_over_columns("bool_or", col)
+
+
+def some(col: "ColumnOrName") -> Column:
+    """Alias of :func:`bool_or`."""
+    return bool_or(col)
+
+
+def any(col: "ColumnOrName") -> Column:
+    """Alias of :func:`bool_or`."""
+    return bool_or(col)
+
+
+def kurtosis(col: "ColumnOrName") -> Column:
+    """Aggregate function: returns the kurtosis of the values in a group.
+
+    .. versionadded:: 1.6.0
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ["v"])
+    >>> df.select(kurtosis("v")).first()[0] is not None
+    True
+    """
+    return _invoke_function_over_columns("kurtosis", col)
+
+
 def broadcast(df: "DataFrame") -> "DataFrame":
     """The broadcast function in Spark is used to optimize joins by broadcasting a smaller dataset
     to all the worker nodes. However, DuckDB operates on a single-node architecture .
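
Note on the multi-column branch of count_distinct above: DuckDB's array_distinct drops both duplicates and NULL elements, so collapsing any row that has a NULL column into a plain NULL before struct-packing yields SQL-standard COUNT(DISTINCT ...) semantics. A minimal standalone sketch of the SQL this effectively builds (illustrative only, not part of the patch):

    import duckdb

    # Rows with any NULL collapse to NULL; array_distinct then removes
    # NULLs and duplicates, leaving only distinct fully-non-NULL pairs.
    res = duckdb.sql(
        """
        SELECT array_length(array_distinct(list(
            CASE WHEN a IS NULL OR b IS NULL THEN NULL
                 ELSE struct_pack(a := a, b := b) END
        ))) AS d
        FROM (VALUES (1, 'a'), (1, 'a'), (1, 'b'), (NULL, 'c'), (2, NULL)) t(a, b)
        """
    ).fetchone()
    print(res)  # (2,)
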
diff --git a/tests/fast/spark/test_spark_functions_aggregate.py b/tests/fast/spark/test_spark_functions_aggregate.py
new file mode 100644
index 00000000..78e40f92
--- /dev/null
+++ b/tests/fast/spark/test_spark_functions_aggregate.py
@@ -0,0 +1,93 @@
+import pytest
+
+_ = pytest.importorskip("duckdb.experimental.spark")
+
+from spark_namespace.sql import functions as F
+from spark_namespace.sql.types import Row
+
+
+class TestSparkAggregateFunctions:
+    def test_count_distinct(self, spark):
+        df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2), ("g", None)], ["k", "v"])
+        res = df.groupBy("k").agg(F.count_distinct("v").alias("d")).collect()
+        assert res == [Row(k="g", d=2)]
+
+    def test_countDistinct_alias(self, spark):
+        df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2)], ["k", "v"])
+        res = df.groupBy("k").agg(F.countDistinct("v").alias("d")).collect()
+        assert res == [Row(k="g", d=2)]
+
+    def test_count_distinct_multi_col(self, spark):
+        df = spark.createDataFrame(
+            [
+                ("g", 1, "a"),
+                ("g", 1, "a"),
+                ("g", 1, "b"),
+                ("g", None, "c"),
+                ("g", 2, None),
+                ("g", None, None),
+            ],
+            ["k", "a", "b"],
+        )
+        res = df.groupBy("k").agg(F.count_distinct("a", "b").alias("d")).collect()
+        assert res == [Row(k="g", d=2)]
+
+    def test_collect_set(self, spark):
+        df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2), ("g", None)], ["k", "v"])
+        row = df.groupBy("k").agg(F.collect_set("v").alias("s")).collect()[0]
+        assert row.k == "g"
+        assert sorted(row.s) == [1, 2]
+
+    def test_count_if(self, spark):
+        df = spark.createDataFrame([("g", 1), ("g", 2), ("g", 3)], ["k", "v"])
+        res = df.groupBy("k").agg(F.count_if(F.col("v") > 1).alias("c")).collect()
+        assert res == [Row(k="g", c=2)]
+
+    def test_max_by(self, spark):
+        df = spark.createDataFrame([("g", "a", 1), ("g", "b", 3), ("g", "c", 2)], ["k", "name", "v"])
+        res = df.groupBy("k").agg(F.max_by("name", "v").alias("m")).collect()
+        assert res == [Row(k="g", m="b")]
+
+    def test_min_by(self, spark):
+        df = spark.createDataFrame([("g", "a", 1), ("g", "b", 3), ("g", "c", 2)], ["k", "name", "v"])
+        res = df.groupBy("k").agg(F.min_by("name", "v").alias("m")).collect()
+        assert res == [Row(k="g", m="a")]
+
+    def test_bool_and(self, spark):
+        df = spark.createDataFrame([("g", True), ("g", True), ("g", False)], ["k", "b"])
+        res = df.groupBy("k").agg(F.bool_and("b").alias("r")).collect()
+        assert res == [Row(k="g", r=False)]
+
+        df2 = spark.createDataFrame([("g", True), ("g", True), ("g", True)], ["k", "b"])
+        res2 = df2.groupBy("k").agg(F.bool_and("b").alias("r")).collect()
+        assert res2 == [Row(k="g", r=True)]
+
+    def test_every_alias(self, spark):
+        df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
+        res = df.groupBy("k").agg(F.every("b").alias("r")).collect()
+        assert res == [Row(k="g", r=False)]
+
+    def test_bool_or(self, spark):
+        df = spark.createDataFrame([("g", True), ("g", False), ("g", False)], ["k", "b"])
+        res = df.groupBy("k").agg(F.bool_or("b").alias("r")).collect()
+        assert res == [Row(k="g", r=True)]
+
+        df2 = spark.createDataFrame([("g", False), ("g", False)], ["k", "b"])
+        res2 = df2.groupBy("k").agg(F.bool_or("b").alias("r")).collect()
+        assert res2 == [Row(k="g", r=False)]
+
+    def test_some_alias(self, spark):
+        df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
+        res = df.groupBy("k").agg(F.some("b").alias("r")).collect()
+        assert res == [Row(k="g", r=True)]
+
+    def test_any_alias(self, spark):
+        df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
+        res = df.groupBy("k").agg(F.any("b").alias("r")).collect()
+        assert res == [Row(k="g", r=True)]
+
+    def test_kurtosis(self, spark):
+        df = spark.createDataFrame([("g", 1.0), ("g", 2.0), ("g", 3.0), ("g", 4.0)], ["k", "v"])
+        row = df.groupBy("k").agg(F.kurtosis("v").alias("kur")).collect()[0]
+        assert row.k == "g"
+        assert row.kur is not None
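
For reference, the new helpers can be exercised end to end roughly as follows; a hedged sketch assuming the usual session bootstrap from the experimental Spark shim (the DataFrame contents here are arbitrary):

    from duckdb.experimental.spark.sql import SparkSession
    from duckdb.experimental.spark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2), ("g", None)], ["k", "v"])

    # count_distinct ignores the NULL row; collect_set drops NULLs and duplicates.
    df.groupBy("k").agg(
        F.count_distinct("v").alias("d"),
        F.collect_set("v").alias("s"),
    ).show()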