duckdb/experimental/spark/sql/functions.py (160 additions, 0 deletions)
@@ -1,5 +1,7 @@
import operator
import warnings
from collections.abc import Callable
from functools import reduce
from typing import TYPE_CHECKING, Any, Optional, Union, overload

from duckdb import (
@@ -6208,6 +6210,164 @@ def expr(str: str) -> Column:
return Column(SQLExpression(str))


def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
"""Aggregate function: returns the number of distinct rows considering the given columns.

Rows where any of the supplied columns is NULL are excluded from the count,
matching Spark / standard SQL `COUNT(DISTINCT col1, col2, ...)` semantics.

.. versionadded:: 1.3.0

    Examples
    --------
>>> df = spark.createDataFrame([(1,), (1,), (2,), (None,)], ["v"])
>>> df.select(count_distinct(df.v).alias("d")).collect()
[Row(d=2)]

>>> df = spark.createDataFrame(
... [(1, "a"), (1, "a"), (1, "b"), (None, "c"), (2, None)], ["a", "b"]
... )
>>> df.select(count_distinct("a", "b").alias("d")).collect()
[Row(d=2)]
"""
exprs = [_to_column_expr(c) for c in (col, *cols)]
if len(exprs) == 1:
arg = exprs[0]
else:
any_null = reduce(operator.or_, (e.isnull() for e in exprs))
arg = CaseExpression(any_null, ConstantExpression(None)).otherwise(FunctionExpression("struct_pack", *exprs))
return _invoke_function(
"array_length",
FunctionExpression(
"array_distinct",
FunctionExpression("list", arg),
),
)
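
For context, this is roughly the DuckDB SQL the expression tree above is intended to stand for when called as count_distinct("a", "b"). It is a sketch only: the exact SQL the resulting Column produces may differ, and the named-argument spelling of struct_pack here is an assumption made for readability.

import duckdb

# Illustrative only -- count the distinct, fully non-NULL (a, b) pairs.
duckdb.sql("""
    SELECT array_length(
               array_distinct(
                   list(CASE WHEN a IS NULL OR b IS NULL
                             THEN NULL
                             ELSE struct_pack(a := a, b := b) END))) AS d
    FROM (VALUES (1, 'a'), (1, 'a'), (1, 'b'), (NULL, 'c'), (2, NULL)) AS t(a, b)
""").show()
# expected: d = 2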


def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
"""Alias of :func:`count_distinct`."""
return count_distinct(col, *cols)


def collect_set(col: "ColumnOrName") -> Column:
"""Aggregate function: returns a set of objects with duplicate elements eliminated.

NULL values are excluded. The order of elements is non-deterministic.

.. versionadded:: 1.6.0

    Examples
    --------
>>> df = spark.createDataFrame([(1,), (1,), (2,)], ["v"])
>>> sorted(df.select(collect_set("v")).first()[0])
[1, 2]
"""
return _invoke_function(
"array_distinct",
FunctionExpression("list", _to_column_expr(col)),
)


def count_if(col: "ColumnOrName") -> Column:
"""Aggregate function: returns the number of `TRUE` values for the expression.

.. versionadded:: 3.5.0

    Examples
    --------
>>> df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
>>> df.select(count_if(df.v > 1).alias("c")).collect()
[Row(c=2)]
"""
return _invoke_function_over_columns("count_if", col)


def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
"""Returns the value associated with the maximum value of `ord`.

.. versionadded:: 3.3.0

    Examples
    --------
>>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
>>> df.select(max_by("k", "v")).first()[0]
'b'
"""
return _invoke_function("arg_max", _to_column_expr(col), _to_column_expr(ord))


def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
"""Returns the value associated with the minimum value of `ord`.

.. versionadded:: 3.3.0

    Examples
    --------
>>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
>>> df.select(min_by("k", "v")).first()[0]
'a'
"""
return _invoke_function("arg_min", _to_column_expr(col), _to_column_expr(ord))


def bool_and(col: "ColumnOrName") -> Column:
"""Aggregate function: returns true if all values of `col` are true.

.. versionadded:: 3.5.0

    Examples
    --------
>>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
>>> df.select(bool_and("b")).first()[0]
False
"""
return _invoke_function_over_columns("bool_and", col)


def every(col: "ColumnOrName") -> Column:
"""Alias of :func:`bool_and`."""
return bool_and(col)


def bool_or(col: "ColumnOrName") -> Column:
"""Aggregate function: returns true if at least one value of `col` is true.

.. versionadded:: 3.5.0

    Examples
    --------
>>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
>>> df.select(bool_or("b")).first()[0]
True
"""
return _invoke_function_over_columns("bool_or", col)


def some(col: "ColumnOrName") -> Column:
"""Alias of :func:`bool_or`."""
return bool_or(col)


def any(col: "ColumnOrName") -> Column:
"""Alias of :func:`bool_or`."""
return bool_or(col)


def kurtosis(col: "ColumnOrName") -> Column:
"""Aggregate function: returns the kurtosis of the values in a group.

.. versionadded:: 1.6.0

    Examples
    --------
>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ["v"])
>>> df.select(kurtosis("v")).first()[0] is not None
True
"""
return _invoke_function_over_columns("kurtosis", col)


def broadcast(df: "DataFrame") -> "DataFrame":
"""The broadcast function in Spark is used to optimize joins by broadcasting a smaller
    dataset to all the worker nodes. However, DuckDB operates on a single-node architecture.
tests/fast/spark/test_spark_functions_aggregate.py (93 additions, 0 deletions)
@@ -0,0 +1,93 @@
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")

from spark_namespace.sql import functions as F
from spark_namespace.sql.types import Row


class TestSparkAggregateFunctions:
def test_count_distinct(self, spark):
df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2), ("g", None)], ["k", "v"])
res = df.groupBy("k").agg(F.count_distinct("v").alias("d")).collect()
assert res == [Row(k="g", d=2)]

def test_countDistinct_alias(self, spark):
df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2)], ["k", "v"])
res = df.groupBy("k").agg(F.countDistinct("v").alias("d")).collect()
assert res == [Row(k="g", d=2)]

def test_count_distinct_multi_col(self, spark):
df = spark.createDataFrame(
[
("g", 1, "a"),
("g", 1, "a"),
("g", 1, "b"),
("g", None, "c"),
("g", 2, None),
("g", None, None),
],
["k", "a", "b"],
)
res = df.groupBy("k").agg(F.count_distinct("a", "b").alias("d")).collect()
assert res == [Row(k="g", d=2)]

def test_collect_set(self, spark):
df = spark.createDataFrame([("g", 1), ("g", 1), ("g", 2), ("g", None)], ["k", "v"])
row = df.groupBy("k").agg(F.collect_set("v").alias("s")).collect()[0]
assert row.k == "g"
assert sorted(row.s) == [1, 2]

def test_count_if(self, spark):
df = spark.createDataFrame([("g", 1), ("g", 2), ("g", 3)], ["k", "v"])
res = df.groupBy("k").agg(F.count_if(F.col("v") > 1).alias("c")).collect()
assert res == [Row(k="g", c=2)]

def test_max_by(self, spark):
df = spark.createDataFrame([("g", "a", 1), ("g", "b", 3), ("g", "c", 2)], ["k", "name", "v"])
res = df.groupBy("k").agg(F.max_by("name", "v").alias("m")).collect()
assert res == [Row(k="g", m="b")]

def test_min_by(self, spark):
df = spark.createDataFrame([("g", "a", 1), ("g", "b", 3), ("g", "c", 2)], ["k", "name", "v"])
res = df.groupBy("k").agg(F.min_by("name", "v").alias("m")).collect()
assert res == [Row(k="g", m="a")]

def test_bool_and(self, spark):
df = spark.createDataFrame([("g", True), ("g", True), ("g", False)], ["k", "b"])
res = df.groupBy("k").agg(F.bool_and("b").alias("r")).collect()
assert res == [Row(k="g", r=False)]

df2 = spark.createDataFrame([("g", True), ("g", True), ("g", True)], ["k", "b"])
res2 = df2.groupBy("k").agg(F.bool_and("b").alias("r")).collect()
assert res2 == [Row(k="g", r=True)]

def test_every_alias(self, spark):
df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
res = df.groupBy("k").agg(F.every("b").alias("r")).collect()
assert res == [Row(k="g", r=False)]

def test_bool_or(self, spark):
df = spark.createDataFrame([("g", True), ("g", False), ("g", False)], ["k", "b"])
res = df.groupBy("k").agg(F.bool_or("b").alias("r")).collect()
assert res == [Row(k="g", r=True)]

df2 = spark.createDataFrame([("g", False), ("g", False)], ["k", "b"])
res2 = df2.groupBy("k").agg(F.bool_or("b").alias("r")).collect()
assert res2 == [Row(k="g", r=False)]

def test_some_alias(self, spark):
df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
res = df.groupBy("k").agg(F.some("b").alias("r")).collect()
assert res == [Row(k="g", r=True)]

def test_any_alias(self, spark):
df = spark.createDataFrame([("g", True), ("g", False)], ["k", "b"])
res = df.groupBy("k").agg(F.any("b").alias("r")).collect()
assert res == [Row(k="g", r=True)]

def test_kurtosis(self, spark):
df = spark.createDataFrame([("g", 1.0), ("g", 2.0), ("g", 3.0), ("g", 4.0)], ["k", "v"])
row = df.groupBy("k").agg(F.kurtosis("v").alias("kur")).collect()[0]
assert row.k == "g"
assert row.kur is not None
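
To run just these tests locally, something along these lines should work (a sketch; assumes pytest and a development build of the duckdb Python package are installed):

import pytest

# Run only the new aggregate-function tests.
pytest.main(["-q", "tests/fast/spark/test_spark_functions_aggregate.py"])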