data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Add type hints / annotations to `array_namespace()`: `array_api_strict` as return type? #194

Open 34j opened 2 weeks ago

34j commented 2 weeks ago

Without type hints, using the returned namespace is very inconvenient, because misspelled array API function names are not caught by editors or type checkers. As array_api_compat seems to be a superset of array_api_strict, I would like to propose simply setting the return type of array_namespace() to array_api_strict, although it might be confusing.

https://github.com/data-apis/array-api-compat/blob/c5ef3dc3d183f536378e53f93f41f1f474780844/array_api_compat/common/_helpers.py#L424

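```python
from typing import overload

import array_api_strict

# Proposed alias: reuse the array_api_strict module as the return annotation
array_api_compat_type = array_api_strict


# Only the overload for use_compat=None is sketched here; the remaining
# overloads and the implementation are omitted.
@overload
def array_namespace(*xs, api_version=None, use_compat: None = None) -> array_api_compat_type:
    ...
```
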
asmeurer commented 1 week ago

I think what we'd want here is upstream work to define a protocol for the array API namespace.

https://github.com/data-apis/array-api/pull/685
https://github.com/data-apis/array-api/issues/267
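For illustration, a minimal sketch of what typing the namespace with a protocol could look like. `ArrayNamespace` and `get_namespace` are hypothetical names, not part of array-api-compat or the linked proposals, and the protocol is trimmed to just two functions; the standard-library `math` module stands in for an array library so the sketch runs without extra dependencies.

```python
import math
from typing import Protocol, runtime_checkable


@runtime_checkable
class ArrayNamespace(Protocol):
    """Hypothetical, heavily trimmed protocol for an array API namespace.

    The real standard specifies far more functions; two are enough to show
    the idea of typing the namespace by its members rather than by a module.
    """

    def sqrt(self, x, /): ...

    def exp(self, x, /): ...


def get_namespace() -> ArrayNamespace:
    # Hypothetical stand-in for array_namespace(). Static checkers that
    # accept modules as protocol implementations (PEP 544) can verify that
    # math provides sqrt and exp.
    return math


xp = get_namespace()
print(isinstance(xp, ArrayNamespace))  # True
print(xp.sqrt(4.0))  # 2.0
# xp.sqroot(4.0)  # a static checker would report that ArrayNamespace has no "sqroot"
```

The `isinstance` check only confirms at runtime that the attributes exist (signatures are not checked); the real benefit for catching misspellings would come from static checkers resolving attribute access against the protocol's members.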

Setting the return type to array_api_strict seems wrong: array_api_strict isn't a type but a module, and in most cases it isn't even the module that's returned. The correct return type would be types.ModuleType, which we can definitely add, although I don't think that alone would let type checkers infer much.
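A minimal sketch of the `types.ModuleType` annotation, using a hypothetical `get_module_namespace` with `math` standing in for the returned array library. It also suggests why this alone may not help much: typeshed's stub for `ModuleType` declares a `__getattr__` returning `Any`, so, as far as we understand, static checkers accept any attribute access on the result, misspelled or not.

```python
import math
import types


def get_module_namespace() -> types.ModuleType:
    # Hypothetical stand-in for array_namespace() annotated with ModuleType;
    # math plays the role of the returned array library module.
    return math


xp = get_module_namespace()
print(isinstance(xp, types.ModuleType))  # True
print(xp.sqrt(9.0))  # 3.0
# A typo such as xp.sqroot(9.0) would still pass a static checker here,
# because ModuleType's stub resolves unknown attributes to Any; at runtime
# it raises AttributeError.
```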