diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 629af98..5c933ef 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -20,7 +20,7 @@ from collections.abc import Iterator from enum import IntEnum from types import EllipsisType, ModuleType -from typing import Any, Final, Literal, SupportsIndex, Callable +from typing import Any, Literal, SupportsIndex, Callable import numpy as np import numpy.typing as npt @@ -40,35 +40,11 @@ _real_to_complex_map, _result_type, ) +from ._devices import CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from ._typing import PyCapsule -class Device: - _device: Final[str] - __slots__ = ("_device", "__weakref__") - - def __init__(self, device: str = "CPU_DEVICE"): - if device not in ("CPU_DEVICE", "device1", "device2"): - raise ValueError(f"The device '{device}' is not a valid choice.") - self._device = device - - def __repr__(self) -> str: - return f"array_api_strict.Device('{self._device}')" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Device): - return False - return self._device == other._device - - def __hash__(self) -> int: - return hash(("Device", self._device)) - - -CPU_DEVICE = Device() -ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) - - class Array: """ n-d array object for the array API namespace. diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 8d3dc60..b6e1d67 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -5,7 +5,11 @@ import numpy as np -from ._dtypes import DType, _all_dtypes, _np_dtype +from ._dtypes import DType, _all_dtypes, _np_dtype, bool as xp_bool +from ._devices import ( + Device, device_supports_dtype, get_default_dtypes, + check_device as _check_device +) from ._flags import get_array_api_strict_flags from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack @@ -14,7 +18,7 @@ from typing_extensions import TypeIs # Circular import - from ._array_object import Array, Device + from ._array_object import Array class Undef(Enum): @@ -24,10 +28,15 @@ class Undef(Enum): _undef = Undef.UNDEF -def _check_valid_dtype(dtype: DType | None) -> None: +def _check_valid_dtype(dtype: DType | None, device: Device | None = None) -> None: # Note: Only spelling dtypes as the dtype objects is supported. - if dtype not in (None,) + _all_dtypes: - raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") + if dtype is not None: + if dtype not in _all_dtypes: + raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") + + if device is not None: + if not device_supports_dtype(device, dtype): + raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.") def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: @@ -38,18 +47,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: return True -def _check_device(device: Device | None) -> None: - # _array_object imports in this file are inside the functions to avoid - # circular imports - from ._array_object import ALL_DEVICES, Device - - if device is not None and not isinstance(device, Device): - raise ValueError(f"Unsupported device {device!r}") - - if device is not None and device not in ALL_DEVICES: - raise ValueError(f"Unsupported device {device!r}") - - def asarray( obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, @@ -65,11 +62,12 @@ def asarray( """ from ._array_object import Array - _check_valid_dtype(dtype) + _check_device(device) + _check_valid_dtype(dtype, device) _np_dtype = None if dtype is not None: _np_dtype = dtype._np_dtype - _check_device(device) + if isinstance(obj, Array) and device is None: device = obj.device @@ -127,8 +125,13 @@ def arange( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + if any(isinstance(x, float) for x in (start, stop, step)): + dtype = get_default_dtypes(device)["real floating"] + else: + dtype = get_default_dtypes(device)["integral"] return Array._new( np.arange(start, stop, step, dtype=_np_dtype(dtype)), @@ -149,8 +152,10 @@ def empty( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device) @@ -165,10 +170,12 @@ def empty_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype + _check_valid_dtype(dtype, device) return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -189,8 +196,10 @@ def eye( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new( np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device @@ -237,12 +246,22 @@ def full( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" raise TypeError(msg) + + if dtype is None: + if type(fill_value) == bool: + dtype = xp_bool + else: + kind = { + int: "integral", float: "real floating", complex: "complex floating" + }[type(fill_value)] + dtype = get_default_dtypes(device)[kind] + res = np.full(shape, fill_value, dtype=_np_dtype(dtype)) if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy @@ -266,10 +285,12 @@ def full_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype + _check_valid_dtype(dtype, device) if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" @@ -300,8 +321,13 @@ def linspace( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + if isinstance(start, complex) or isinstance(stop, complex): + dtype = get_default_dtypes(device)["complex floating"] + else: + dtype = get_default_dtypes(device)["real floating"] return Array._new( np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint), @@ -353,8 +379,10 @@ def ones( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device) @@ -369,10 +397,12 @@ def ones_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype + _check_valid_dtype(dtype, device) return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -418,8 +448,10 @@ def zeros( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device) @@ -434,9 +466,11 @@ def zeros_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype + _check_valid_dtype(dtype, device) return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/_devices.py b/array_api_strict/_devices.py new file mode 100644 index 0000000..cd76eae --- /dev/null +++ b/array_api_strict/_devices.py @@ -0,0 +1,101 @@ +from typing import Final + +from ._dtypes import ( + DType, float32, float64, complex64, complex128, int64, + _all_dtypes, _boolean_dtypes, _signed_integer_dtypes, + _unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes, + _complex_floating_dtypes, _numeric_dtypes +) + +_ALL_DEVICE_NAMES = ("CPU_DEVICE", "device1", "device2", "F32_device") + +class Device: + _device: Final[str] + __slots__ = ("_device", "__weakref__") + + def __init__(self, device: str = "CPU_DEVICE"): + if device not in _ALL_DEVICE_NAMES: + raise ValueError(f"The device '{device}' is not a valid choice.") + self._device = device + + def __repr__(self) -> str: + return f"array_api_strict.Device('{self._device}')" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Device): + return False + return self._device == other._device + + def __hash__(self) -> int: + return hash(("Device", self._device)) + + def _supported_dtypes(self) -> list[DType]: + # XXX useful? Unused ATM + return list(dt for dt in _all_dtypes if device_supports_dtype(self, dt)) + + +CPU_DEVICE = Device() +_F32_DEVICE = Device("F32_device") + +ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), _F32_DEVICE) + + +def check_device(device: Device | None) -> None: + if device is not None and not isinstance(device, Device): + raise ValueError(f"Unsupported device {device!r}") + + if device is not None and device not in ALL_DEVICES: + raise ValueError(f"Unsupported device {device!r}") + + +# Helpers for device-specific dtype support + +def get_default_dtypes(device: Device | None = None) -> dict[str, Device]: + if device == _F32_DEVICE: + return { + "real floating": float32, + "complex floating": complex64, + "integral": int64, + "indexing": int64, + } + else: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + + +def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool: + """True if `device` supports `dtype`, False otherwise.""" + # special-case F32_device + if device == _F32_DEVICE: + return dtype not in (float64, complex128) + + # All other devices support all dtypes + return True + + +def _map_supported(dtypes: list[DType], device: Device) -> dict[str, DType]: + return { + dt._canonic_name: dt + for dt in dtypes + if device_supports_dtype(device, dt) + } + + +# _info.dtypes() maps "kind" -> dict of {name: dtype} +# Note that "kinds" differ from "categories" above, per the spec. + +_kind_to_dtypes = { + None: _all_dtypes, + "bool": _boolean_dtypes, + "signed integer": _signed_integer_dtypes, + "unsigned integer": _unsigned_integer_dtypes, + "integral": _integer_dtypes, + "real floating": _real_floating_dtypes, + "complex floating": _complex_floating_dtypes, + "numeric": _numeric_dtypes +} + diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 564db5a..2d5ec33 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -11,9 +11,11 @@ class DType: _np_dtype: Final[np.dtype[Any]] - __slots__ = ("_np_dtype", "__weakref__") + _canonic_name: Final[Any] + __slots__ = ("_np_dtype", "_canonic_name", "__weakref__") def __init__(self, np_dtype: npt.DTypeLike): + self._canonic_name = np_dtype self._np_dtype = np.dtype(np_dtype) def __repr__(self) -> str: diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index c2c617e..9f0dfcf 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -3,7 +3,8 @@ import numpy as np -from ._array_object import ALL_DEVICES, Array, Device +from ._array_object import Array +from ._devices import ALL_DEVICES, Device from ._data_type_functions import astype from ._dtypes import ( DType, diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 12beed0..cbee036 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,7 +1,7 @@ import numpy as np -from . import _dtypes as dt -from ._array_object import ALL_DEVICES, CPU_DEVICE, Device +from . import _devices +from ._devices import ALL_DEVICES, CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version from ._typing import Capabilities, DataTypes, DefaultDataTypes @@ -40,12 +40,7 @@ def default_dtypes( *, device: Device | None = None, ) -> DefaultDataTypes: - return { - "real floating": dt.float64, - "complex floating": dt.complex128, - "integral": dt.int64, - "indexing": dt.int64, - } + return _devices.get_default_dtypes(device) @requires_api_version('2023.12') def dtypes( @@ -54,78 +49,21 @@ def dtypes( device: Device | None = None, kind: str | tuple[str, ...] | None = None, ) -> DataTypes: - if kind is None: - return { - "bool": dt.bool, - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - "float32": dt.float32, - "float64": dt.float64, - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if kind == "bool": - return {"bool": dt.bool} - if kind == "signed integer": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - } - if kind == "unsigned integer": - return { - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - } - if kind == "integral": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - } - if kind == "real floating": - return { - "float32": dt.float32, - "float64": dt.float64, - } - if kind == "complex floating": - return { - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if kind == "numeric": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - "float32": dt.float32, - "float64": dt.float64, - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if isinstance(kind, tuple): + if device is None: + device = CPU_DEVICE + if isinstance(kind, type(None) | str): + + try: + dtypes = _devices._kind_to_dtypes[kind] + except KeyError: + raise ValueError(f"unsupported kind: {kind!r}") + res = _devices._map_supported(dtypes, device) + return res + + elif isinstance(kind, tuple): res: DataTypes = {} for k in kind: - res.update(self.dtypes(kind=k)) + res.update(self.dtypes(kind=k, device=device)) return res raise ValueError(f"unsupported kind: {kind!r}") diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 6736826..b8b71be 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -23,7 +23,9 @@ zeros_like, ) from .._dtypes import float32, float64 -from .._array_object import Array, CPU_DEVICE, Device +from .._array_object import Array +from .._devices import CPU_DEVICE, ALL_DEVICES, Device +from .._info import __array_namespace_info__ from .._flags import set_array_api_strict_flags def test_asarray_errors(): @@ -212,6 +214,7 @@ def test_zeros_like_errors(): assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i")) + def test_meshgrid_dtype_errors(): # Doesn't raise meshgrid() @@ -221,6 +224,94 @@ def test_meshgrid_dtype_errors(): assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) + +def _full(a, *args, **kwds): + return full(a, fill_value=42.0, *args, **kwds) + + +def _full_like(a, *args, **kwds): + return full_like(a, fill_value=42.0, *args, **kwds) + + +class TestDefaultDType: + + info = __array_namespace_info__() + + @pytest.mark.parametrize("device", ALL_DEVICES) + @pytest.mark.parametrize("func", [empty, zeros, ones, _full]) + def test_ones_etc(self, func, device): + a = func(1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + @pytest.mark.parametrize("func", [empty_like, zeros_like, ones_like, _full_like]) + def test_ones_like_etc_correct(self, func): + # float32 is preserved + a = ones(2, dtype=float32) + device = Device('F32_device') + b = func(a, device=device) + assert b.dtype == self.info.default_dtypes(device=device)["real floating"] + + @pytest.mark.parametrize("func", [empty_like, zeros_like, ones_like, _full_like]) + def test_ones_like_etc_incorrect(self, func): + a = ones(2) + assert a.dtype == float64 + assert a.device == Device() + + # XXX: a.dtype not supported by the device: ValueError or TypeError? + + # >>> a = torch.ones(3, dtype=torch.float64, device='cpu') + # >>> torch.ones_like(a, device='mps') + # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework + # doesn't support float64. + + # incompatible dtype inferred from `a.dtype` + with pytest.raises((TypeError, ValueError)): + func(a, device=Device('F32_device')) + + # `a.dtype` is compatible but the explicit dtype= argument is incompatible + a = ones(2, dtype=float32) + with pytest.raises((TypeError, ValueError)): + func(a, device=Device('F32_device'), dtype=float64) + + def test_eye(self): + device = Device('F32_device') + a = eye(3, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + with pytest.raises((TypeError, ValueError)): + eye(3, device=device, dtype=float64) + + def test_linspace(self): + device = Device('F32_device') + + a = linspace(1, 10, 11, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + a = linspace(1+0j, 10, 11, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["complex floating"] + + with pytest.raises((TypeError, ValueError)): + linspace(1, 10, 11, device=device, dtype=float64) + + def test_arange(self): + device = Device('F32_device') + + a = arange(0, 10, 1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["integral"] + + a = arange(0.0, 10, 1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + with pytest.raises((TypeError, ValueError)): + arange(0, 10, 1, device=device, dtype=float64) + + with pytest.raises((TypeError, ValueError)): + arange(0.0, 10, 1, device=device, dtype=float64) + +# TODO: +# def asarray( + + @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def from_dlpack_2023_12(api_version): if api_version != '2022.12': diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py index 0f3d6b5..242151a 100644 --- a/array_api_strict/tests/test_device_support.py +++ b/array_api_strict/tests/test_device_support.py @@ -1,6 +1,6 @@ import pytest -import array_api_strict +import array_api_strict as xp @pytest.mark.parametrize( @@ -18,11 +18,11 @@ ), ) def test_fft_device_support_complex(func_name): - func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray( + func = getattr(xp.fft, func_name) + x = xp.asarray( [1, 2.0], - dtype=array_api_strict.complex64, - device=array_api_strict.Device("device1"), + dtype=xp.complex64, + device=xp.Device("device1"), ) y = func(x) @@ -31,8 +31,38 @@ def test_fft_device_support_complex(func_name): @pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft")) def test_fft_device_support_real(func_name): - func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray([1, 2.0], device=array_api_strict.Device("device1")) + func = getattr(xp.fft, func_name) + x = xp.asarray([1, 2.0], device=xp.Device("device1")) y = func(x) assert x.device == y.device + + +class TestF32Device: + @pytest.mark.parametrize("dtype_str", ["float64", "complex128"]) + def test_f64_raises(self, dtype_str): + f32_device = xp.Device("F32_device") + dtype = getattr(xp, dtype_str) + with pytest.raises(ValueError): + xp.arange(3, device=f32_device, dtype=dtype) + + def test_info_no_f64(self): + f32_device = xp.Device("F32_device") + + info = xp.__array_namespace_info__() + all_dtypes = info.dtypes(device=f32_device) + assert "float64" not in all_dtypes + assert "complex128" not in all_dtypes + + def test_info_default_dtypes(self): + f32_device = xp.Device("F32_device") + info = xp.__array_namespace_info__() + defaults = info.default_dtypes(device=f32_device) + assert defaults["real floating"] == xp.float32 + assert defaults["complex floating"] == xp.complex64 + + cpu_device = xp.Device() + info = xp.__array_namespace_info__() + defaults = info.default_dtypes(device=cpu_device) + assert defaults["real floating"] == xp.float64 + assert defaults["complex floating"] == xp.complex128 diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 050f2bc..7fb6e33 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -6,7 +6,7 @@ from .. import asarray, _elementwise_functions -from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._devices import ALL_DEVICES, CPU_DEVICE, Device from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, @@ -21,6 +21,7 @@ int64, uint64, ) +from .._info import __array_namespace_info__ from .test_array_object import _check_op_array_scalar, BIG_INT import array_api_strict @@ -144,6 +145,10 @@ def _array_vals(dtypes): yield asarray(1., dtype=dtype, device=device) dtypes = _dtype_categories[types] + + supported_dtypes = __array_namespace_info__().dtypes(device=device) + dtypes = [dt for dt in dtypes if dt in supported_dtypes] + func = getattr(_elementwise_functions, func_name) for x in _array_vals(dtypes): diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index abe1949..18775ed 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,10 +3,11 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags -from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._devices import ALL_DEVICES, CPU_DEVICE, Device from .._dtypes import _all_dtypes + def test_where_with_scalars(): x = xp.asarray([1, 2, 3, 1])