data-apis / array-api-strict

Strict implementation of the Python array API (previously numpy.array_api)
http://data-apis.org/array-api-strict/

array-api-strict creates an empty iterable rather than raising an error #41

Closed: j-bowhay closed this issue 3 months ago

j-bowhay commented 3 months ago

For example:

In [47]: import array_api_strict as xp

In [48]: x = xp.ones((2,2))

In [49]: list(iter(x))
Out[49]: []

I think this is because __getitem__ raises an IndexError:

In [50]: x.__getitem__(0)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[50], line 1
----> 1 x.__getitem__(0)

File ~/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/_array_object.py:588, in Array.__getitem__(self, key)
    583 """
    584 Performs the operation __getitem__.
    585 """
    586 # Note: Only indices required by the spec are allowed. See the
    587 # docstring of _validate_index
--> 588 self._validate_index(key)
    589 if isinstance(key, Array):
    590     # Indexing self._array with array_api_strict arrays can be erroneous
    591     key = key._array

File ~/miniconda3/envs/scipy-dev-pytorch/lib/python3.11/site-packages/array_api_strict/_array_object.py:374, in Array._validate_index(self, key)
    370 elif n_ellipsis == 0:
    371     # Note boolean masks must be the sole index, which we check for
    372     # later on.
    373     if not key_has_mask and n_single_axes < self.ndim:
--> 374         raise IndexError(
    375             f"{self.ndim=}, but the multi-axes index only specifies "
    376             f"{n_single_axes} dimensions. If this was intentional, "
    377             "add a trailing ellipsis (...) which expands into as many "
    378             "slices (:) as necessary - this is what np.ndarray arrays "
    379             "implicitly do, but such flat indexing behaviour is not "
    380             "specified in the Array API."
    381         )
    383 if n_ellipsis == 0:
    384     indexed_shape = self.shape

IndexError: self.ndim=2, but the multi-axes index only specifies 1 dimensions. If this was intentional, add a trailing ellipsis (...) which expands into as many slices (:) as necessary - this is what np.ndarray arrays implicitly do, but such flat indexing behaviour is not specified in the Array API.
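
The check that fires at line 374 of _array_object.py can be paraphrased roughly as below. This is a simplified, hypothetical sketch (check_index is not part of array-api-strict) that ignores boolean masks and other index types. It only shows why x[0] is rejected on a 2-D array while x[0, ...] passes, as the error message suggests.

```python
def check_index(ndim: int, key) -> None:
    """Hypothetical, simplified paraphrase of the multi-axis index rule.

    Without an ellipsis, an index must address every axis; boolean
    masks, arrays, and None are deliberately ignored here.
    """
    if not isinstance(key, tuple):
        key = (key,)  # x[0] behaves like the 1-tuple index x[(0,)]
    n_ellipsis = sum(k is Ellipsis for k in key)
    # Integers and slices each consume exactly one axis.
    n_single_axes = sum(isinstance(k, (int, slice)) for k in key)
    if n_ellipsis == 0 and n_single_axes < ndim:
        raise IndexError(
            f"ndim={ndim}, but the multi-axes index only specifies "
            f"{n_single_axes} dimensions"
        )


check_index(2, (0, ...))            # accepted: the ellipsis fills axis 1
check_index(2, (0, slice(None)))    # accepted: both axes given explicitly
try:
    check_index(2, 0)               # what iteration effectively tries first
except IndexError as exc:
    print(exc)  # ndim=2, but the multi-axes index only specifies 1 dimensions
```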

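IndexError is meant for out-of-bounds indices, but when a class defines __getitem__ and no __iter__, Python's fallback iteration protocol treats an IndexError from x[0] as the end of the sequence, so iteration silently produces nothing. Instead, I think we should be raising a TypeError.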
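
A rough emulation of that fallback protocol, not CPython's actual implementation: iter() calls obj[0], obj[1], ... and stops quietly at the first IndexError. The generator name legacy_iter is hypothetical and used only for illustration.

```python
from itertools import count


def legacy_iter(obj):
    """Roughly emulate iter(obj) for a class with __getitem__ but no __iter__."""
    for i in count():  # i = 0, 1, 2, ...
        try:
            item = obj[i]
        except IndexError:
            # IndexError is read as "past the end", so iteration just stops.
            return
        yield item


class A:
    # Every lookup raises IndexError, like Array.__getitem__(0) on a 2-D array.
    def __getitem__(self, key):
        raise IndexError


print(list(legacy_iter(A())))  # []
print(list(iter(A())))         # []
```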

This can also be reproduced with the following example:

In [51]: class A:
    ...:     def __getitem__(self, key):
    ...:         raise IndexError
    ...:

In [52]: list(iter(A()))
Out[52]: []
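
For contrast, a minimal sketch of raising TypeError from __getitem__ instead: the fallback protocol only treats IndexError as the end of iteration, so a TypeError propagates and the failure is loud. Class B is hypothetical and only mirrors A above.

```python
class B:
    """Like A, but __getitem__ raises TypeError instead of IndexError."""

    def __getitem__(self, key):
        raise TypeError("integer indexing is not supported")


try:
    # iter() itself succeeds because B defines __getitem__; the error
    # surfaces on the first lookup, B()[0], when list() pulls an item.
    list(iter(B()))
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: integer indexing is not supported
```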

asmeurer commented 3 months ago

Ah, good catch. I think I'd rather keep the IndexError for __getitem__ and just manually implement an __iter__ that raises TypeError.
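
The approach in this reply could look roughly like the sketch below. StrictArrayLike is a hypothetical stand-in, not the real array_api_strict Array class, and its error messages are made up. __getitem__ keeps raising IndexError, while an explicit __iter__ raises TypeError; since Python uses __iter__ whenever a class defines it, the __getitem__ fallback never runs.

```python
class StrictArrayLike:
    """Hypothetical stand-in for an array class that forbids iteration."""

    def __init__(self, ndim: int):
        self.ndim = ndim

    def __getitem__(self, key):
        # Invalid (under-specified) indices keep raising IndexError.
        raise IndexError(f"ndim={self.ndim}, but the index does not cover every axis")

    def __iter__(self):
        # Takes priority over the __getitem__ fallback, so iter(), list()
        # and for-loops fail loudly instead of yielding nothing.
        raise TypeError("iteration over this array is not allowed")


x = StrictArrayLike(ndim=2)
try:
    list(x)
except TypeError as exc:
    print(exc)  # iteration over this array is not allowed
```

A related idiom from the Python data model is setting __iter__ = None in the class body, which also makes iter() raise TypeError without falling back to __getitem__; an explicit method simply allows a custom error message.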