data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Signatures of `get_xp`-wrapped functions? #153

Open mdhaber opened 2 months ago

mdhaber commented 2 months ago

There seem to be some issues with the signatures of functions wrapped by get_xp. I haven't narrowed down the exact problem, but here's an MRE:

import cupy as xp
from array_api_compat import cupy as xp_compat

A = xp.eye(3)
A = xp.asarray(A)

xp.linalg.eigh(A)  # fine
xp.linalg.eigh(a=A)  # fine

xp_compat.linalg.eigh(A)  # fine
xp_compat.linalg.eigh(a=A)  # error
# TypeError: eigh() missing 1 required positional argument: 'x'

Also, e.g.

xp.linalg.eigh(A, 'U')  # fine
xp_compat.linalg.eigh(A, 'U')  # error
# TypeError: eigh() got multiple values for argument 'xp'

asmeurer commented 2 months ago

The problem is straightforward. eigh is defined as

https://github.com/data-apis/array-api-compat/blob/ac15c526d9769f77c780958a00097dfd183a2a37/array_api_compat/common/_linalg.py#L45-L46

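In other words, it passes **kwargs through but doesn't pass *args through, so an extra positional argument such as 'U' is bound to the xp parameter, which the decorator also supplies by keyword. This is easy enough to fix: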

diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..01db3a0 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):

 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(x: ndarray, /, *args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))

 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:

We just need to do this for every function for which NumPy supports more positional arguments than the standard.
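To make the failure mode concrete without needing CuPy installed, here is a minimal sketch. The `get_xp` below is a simplified stand-in that only injects the namespace as the keyword `xp` (an assumption about the essential behavior, not the library's actual code), and `fake_xp` / `backend_eigh` are hypothetical placeholders for a NumPy-style backend.

```python
import functools
from types import SimpleNamespace
from typing import NamedTuple


class EighResult(NamedTuple):
    eigenvalues: object
    eigenvectors: object


def backend_eigh(a, UPLO="L"):
    # Hypothetical stand-in for a NumPy-style eigh(a, UPLO='L').
    return ("w", f"v[{UPLO}]")


# Hypothetical namespace playing the role of numpy/cupy.
fake_xp = SimpleNamespace(linalg=SimpleNamespace(eigh=backend_eigh))


def get_xp(xp):
    # Simplified stand-in: always passes the namespace as keyword `xp`.
    def inner(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return inner


@get_xp(fake_xp)
def eigh_before(x, /, xp, **kwargs):
    # A second positional argument binds to `xp`, clashing with xp=... above.
    return EighResult(*xp.linalg.eigh(x, **kwargs))


@get_xp(fake_xp)
def eigh_after(x, /, *args, xp, **kwargs):
    # `xp` is keyword-only here, so extra positionals flow into *args.
    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))


before_error = None
try:
    eigh_before("A", "U")
except TypeError as exc:
    before_error = str(exc)

after_result = eigh_after("A", "U")
print(before_error)
print(after_result)
```

With the `*args` variant, `xp` becomes keyword-only, which is what removes the clash.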

As for the other issue, eigh(a=A): here you are passing by keyword (using NumPy's name, a) an argument that is positional-only in the standard.
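A small sketch of why the keyword call fails, using a hypothetical function with the same shape as the wrapper (positional-only `x` plus `**kwargs`). Any keyword, including `x=` itself, is collected into `kwargs`, which leaves the positional-only parameter unfilled.

```python
def eigh(x, /, **kwargs):
    # `x` is positional-only, so keywords such as a=... or x=... end up
    # in kwargs and never satisfy the required first parameter.
    return x, kwargs


errors = {}
calls = {
    "a=": lambda: eigh(a="A"),
    "x=": lambda: eigh(x="A"),
}
for label, call in calls.items():
    try:
        call()
    except TypeError as exc:
        errors[label] = str(exc)

ok = eigh("A", UPLO="U")  # positional first argument works as expected
print(errors)
print(ok)
```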

We could make this work by instead defining all functions like this:

diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..11b54bb 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):

 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(*args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(*args, **kwargs))

 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:

The downside is that this would completely kill the introspectability of the functions (right now I have it set up so that help(array_api_compat.numpy.linalg.eigh) shows the actual arguments).
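The loss of introspectability can be seen with `inspect.signature`, which is also what `help()` relies on to render a function's arguments. The two functions below are hypothetical stubs that mirror only the signatures in the diff, not the real implementations.

```python
import inspect


def eigh_explicit(x, /, xp, **kwargs):
    """Stub with the explicit, standard-like signature."""


def eigh_generic(*args, xp, **kwargs):
    """Stub with the fully generic pass-through signature."""


explicit_sig = str(inspect.signature(eigh_explicit))
generic_sig = str(inspect.signature(eigh_generic))
print(explicit_sig)
print(generic_sig)
```

The generic form no longer tells the reader that there is a required array argument at all.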

Note that both of these examples are not portable with the standard, which defines the signature as

eigh(x: array, /)

mdhaber commented 2 months ago

> Note that both of these examples are not portable with the standard

Yeah, the thing is that this came up in the context of dispatching calls to SciPy's eigh function to other backends. It's unclear ATM what we want to do when the SciPy function has a much more flexible signature (including many other arguments) than the standard.

That said, my impression from here was that array_api_compat did not intend to limit capabilities of the wrapped libraries to those of the standard, so I went ahead and reported it.

asmeurer commented 2 months ago

Yes, in principle we should support this. Maybe I can modify the get_xp decorator to keep the standard signature for introspection purposes, but always pass through *args and **kwargs automatically. I'll need to think a bit about it.
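One way the idea could look, purely as a sketch: the implementation accepts `*args` and `**kwargs`, while the decorator attaches a separate signature for introspection. Everything here is hypothetical (`get_xp_passthrough`, the `documented_as` parameter, the `_eigh_standard` stub, and the `fake_xp` namespace are illustrative names, not part of array-api-compat). It relies only on the fact that `inspect.signature` honors a `__signature__` attribute when one is set.

```python
import functools
import inspect
from types import SimpleNamespace


def backend_eigh(a, UPLO="L"):
    # Hypothetical stand-in for a NumPy-style eigh(a, UPLO='L').
    return ("w", f"v[{UPLO}]")


fake_xp = SimpleNamespace(linalg=SimpleNamespace(eigh=backend_eigh))


def get_xp_passthrough(xp, documented_as):
    # Hypothetical decorator: forwards every argument, but reports the
    # signature of `documented_as` to inspect.signature() and help().
    def inner(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        wrapped.__signature__ = inspect.signature(documented_as)
        return wrapped
    return inner


def _eigh_standard(x, /):
    """Stub carrying the standard's signature; never called."""


@get_xp_passthrough(fake_xp, documented_as=_eigh_standard)
def eigh(*args, xp, **kwargs):
    return tuple(xp.linalg.eigh(*args, **kwargs))


shown = str(inspect.signature(eigh))   # what introspection reports
positional = eigh("A", "U")            # extra positional passes through
keyword = eigh(a="A")                  # backend keyword name passes through
print(shown, positional, keyword)
```

The trade-off in this sketch is that the reported signature and the accepted arguments can drift apart, since nothing checks one against the other.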