Skip to content
248 changes: 140 additions & 108 deletions array_api_compat/common/_aliases.py

Large diffs are not rendered by default.

87 changes: 45 additions & 42 deletions array_api_compat/common/_fft.py
Original file line number Diff line number Diff line change
@@ -1,149 +1,148 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal
from collections.abc import Sequence
from typing import Union, Optional, Literal

if TYPE_CHECKING:
from ._typing import Device, ndarray, DType
from collections.abc import Sequence
from ._typing import Device, Array, DType, Namespace

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.

def fft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This literal is repeated quite often, so it's probably a good idea to extract it as something like

_Norm: TypeAlias = Literal["backward", "ortho", "forward"]

) -> ndarray:
) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 44 to 45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

FYI, a Sequence[int] is very broad, and will even accept things like bytes. Maybe a tuple[int, ...] is more appropriate here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 58 to 59
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def rfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def rfftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 100 to 101
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 114 to 115
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def hfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res

def ihfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
Expand All @@ -152,12 +151,12 @@ def ihfft(
def fftfreq(
n: int,
/,
xp,
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> ndarray:
device: Optional[Device] = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.fftfreq(n, d=d)
Expand All @@ -168,23 +167,27 @@ def fftfreq(
def rfftfreq(
n: int,
/,
xp,
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> ndarray:
device: Optional[Device] = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.rfftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res

def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
def fftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None

) -> Array:
return xp.fft.fftshift(x, axes=axes)

def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
def ifftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None

) -> Array:
return xp.fft.ifftshift(x, axes=axes)

__all__ = [
Expand Down
32 changes: 17 additions & 15 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from ._typing import Array, Device, Namespace

import sys
import math
import inspect
import warnings
from typing import Optional, Union, Any

from ._typing import Array, Device, Namespace


def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.
Expand Down Expand Up @@ -268,7 +266,7 @@
return __name__.removesuffix('.common._helpers')


def is_numpy_namespace(xp) -> bool:
def is_numpy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a NumPy namespace.

Expand All @@ -289,7 +287,7 @@
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}


def is_cupy_namespace(xp) -> bool:
def is_cupy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a CuPy namespace.

Expand All @@ -310,7 +308,7 @@
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}


def is_torch_namespace(xp) -> bool:
def is_torch_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.

Expand All @@ -331,7 +329,7 @@
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp) -> bool:
def is_ndonnx_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is an NDONNX namespace.

Expand All @@ -350,7 +348,7 @@
return xp.__name__ == 'ndonnx'


def is_dask_namespace(xp) -> bool:
def is_dask_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a Dask namespace.

Expand All @@ -371,7 +369,7 @@
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}


def is_jax_namespace(xp) -> bool:
def is_jax_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a JAX namespace.

Expand All @@ -393,7 +391,7 @@
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}


def is_pydata_sparse_namespace(xp) -> bool:
def is_pydata_sparse_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a pydata/sparse namespace.

Expand All @@ -412,7 +410,7 @@
return xp.__name__ == 'sparse'


def is_array_api_strict_namespace(xp) -> bool:
def is_array_api_strict_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.

Expand All @@ -439,7 +437,11 @@
raise ValueError("Only the 2024.12 version of the array API specification is currently supported")


def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
def array_namespace(
*xs: Union[Array, bool, int, float, complex, None],

Check failure on line 441 in array_api_compat/common/_helpers.py

View workflow job for this annotation

GitHub Actions / check-ruff

array_api_compat/common/_helpers.py:441:10: SyntaxError: Cannot use star annotation on Python 3.9 (syntax was added in Python 3.11)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's sufficient to write

Suggested change
*xs: Union[Array, bool, int, float, complex, None],
*xs: Array | complex | None,

Keep in mind that in case of type-checker errors, often the entire signature is printed. So it's usually a good idea to avoid long annotations (in function signatures, but also in general).

api_version: Optional[str] = None,
use_compat: Optional[bool] = None,
) -> Namespace:
"""
Get the array API compatible namespace for the arrays `xs`.

Expand Down
Loading
Loading