betatim closed this pull request 1 month ago
+1! Another question is what the default should be (technically Device("pony") is more strict), but it's probably better if we can keep the CPU default for backwards compatibility.
I think the CPU device should be the default. That way code that exists today should keep working and the only people who notice any changes are those who use the pony device.
This looks good so far. We need to make sure the semantics specified at https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics are followed, namely, disallowing combining arrays from different devices, and making sure that if a function creates a new array based on an existing array that it uses the same device.
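For concreteness, here is a minimal sketch of the behaviour those semantics imply (the Device("pony") spelling follows the PR description, and the exact exception type is an assumption, not a decision made in this PR):

```python
import array_api_strict as xp

a = xp.asarray([1, 2, 3], device=xp.Device("pony"))
b = xp.asarray([4, 5, 6])  # default (CPU) device

# A new array derived from an existing array keeps its device.
assert xp.sin(xp.astype(a, xp.float64)).device == a.device

# Combining arrays from different devices must be disallowed.
try:
    a + b
except Exception:  # exact exception type is not pinned down here
    pass
```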
For tests, ideally this would be tested in array-api-tests, but right now device support is not tested at all there. If you just want to add some basic tests here for now, that is fine.
Finally, there is the devices inspection API: https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.devices.html#array_api.info.devices We need to think about how that will work. One option would just be to create a small but fixed number of devices. Or we could add some flags to make it configurable: https://data-apis.org/array-api-strict/api.html#array_api_strict.set_array_api_strict_flags
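Roughly, the inspection entry point would need to enumerate whatever devices the library decides to expose; something like the following, assuming the 2023.12 inspection API, with the exact device list being the open question:

```python
import array_api_strict as xp

info = xp.__array_namespace_info__()
print(info.devices())         # a small, fixed list of Device objects
print(info.default_device())  # the CPU device by default
```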
I've restarted work on this. The long pause is because I went on holiday :D
I think we have to limit ourselves to a fixed number of devices, otherwise we can't fulfill the requirement that the info extension can provide a list of devices. So now you can use the CPU_DEVICE and the (creatively named) device1 and device2.
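Usage with the fixed set of devices would then look roughly like this (a sketch; whether an arbitrary device name gets rejected is an assumption about this PR, not something stated above):

```python
import array_api_strict as xp

x = xp.asarray([1, 2, 3])                             # defaults to the CPU device
y = xp.asarray([1, 2, 3], device=xp.Device("device1"))
z = xp.asarray([1, 2, 3], device=xp.Device("device2"))
print(x.device, y.device, z.device)

# With a fixed set of devices, an arbitrary name would presumably be rejected:
# xp.Device("pony")  -> error (assumption)
```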
Slowly making progress towards the creation functions and "array combination" functions respecting the device.
Do you use ruff or black or something like that for formatting?
It looks like it would be quite tricky to add device testing to array-api-tests. At least at my level of knowledge (unfamiliar with hypothesis and array-api-tests). It looks like you'd add a helper, maybe all_devices, similar to all_dtypes, and then use it in the @given decorator. The tricky thing is that how to specify a device depends on the library; in a new version of the standard you could use the inspection API to get all devices. So yeah, for now I might add some basic testing here.
Do you use ruff or black or something like that for formatting?
There's no autoformatting on this repo.
It looks like it would be quite tricky to add device testing to array-api-tests. At least at my level of knowledge (unfamiliar with hypothesis and array-api-tests). It looks like you'd add a helper, maybe all_devices, similar to all_dtypes and then use it in the @given decorator. The tricky thing is that how to specify a device depends on the library, in a new version of the standard you could use the inspection API to get all devices. So yeah, for now I might add some basic testing here.
I think it would have to use the devices() function in the inspection API. That would mean that the tests would only work against the newest version of the standard, and it would only work against the compat library, but I think that's fine. You'd also probably want to make it optional.
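As a rough sketch of what that could look like, with all_devices as a hypothetical helper analogous to all_dtypes, and assuming the namespace under test implements the 2023.12 inspection API:

```python
from hypothesis import given, strategies as st
import array_api_strict as xp  # stand-in for the namespace under test

# Hypothetical all_devices strategy, built from the inspection API.
all_devices = st.sampled_from(xp.__array_namespace_info__().devices())

@given(device=all_devices)
def test_asarray_respects_device(device):
    x = xp.asarray([0, 1, 2], device=device)
    assert x.device == device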
It's also possible to do some basic testing using the default device, like checking that x.device and device= are consistent.
The annoying thing for the test suite is making sure every function everywhere is passing device through properly so that everything gets created on the same device. It would also probably require some upstream fixes to the hypothesis array-api support.
I think what we need here are just some big parameterized tests combining basic example arrays with different devices across all the APIs. For instance, there's an existing test that checks type promotion and the "no mixing devices" test could look very similar to that.
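Something along these lines, reusing a list of two-argument functions (the function list and device names here are placeholders, and the expected exception type is an assumption):

```python
import pytest
import array_api_strict as xp

# Placeholder list; the existing tests already have a similar list of
# two-argument functions that could be reused.
binary_funcs = [xp.add, xp.multiply, xp.equal]
devices = [None, xp.Device("device1"), xp.Device("device2")]

@pytest.mark.parametrize("func", binary_funcs)
@pytest.mark.parametrize("device1", devices)
@pytest.mark.parametrize("device2", devices)
def test_no_mixing_devices(func, device1, device2):
    x = xp.asarray([1, 2, 3], device=device1)
    y = xp.asarray([1, 2, 3], device=device2)
    if x.device == y.device:
        assert func(x, y).device == x.device
    else:
        with pytest.raises(Exception):  # exact exception type TBD in this PR
            func(x, y)
```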
Does someone know more about the failure? It looks like it is not to do with the actual code but with computing the expected shape, which overflows because the array is of dtype int8. How do we fix that?
Sorry, that is from a new test that I added in the test suite. I guess I didn't catch all the corner cases. You can ignore it for now.
In that case, I think this is ready?!
I tried to modify all the functions that return an array to take into account the device of the input. For some I've added tests that check this, but I think really array-api-tests should be doing this (check that input and output devices are consistent). So maybe adding tests for all functions is not needed.
It's hard to tell just from the diff if you missed anything. Here are all the places in the code that call _new without a device keyword:
$ git grep -n '_new(' | grep -v 'device'
array_api_strict/_array_object.py:278: return Array._new(np.array(scalar, dtype=self.dtype._np_dtype))
array_api_strict/_array_object.py:310: x1 = Array._new(x1._array[None])
array_api_strict/_array_object.py:312: x2 = Array._new(x2._array[None])
array_api_strict/_array_object.py:496: return self.__class__._new(res)
array_api_strict/_array_object.py:684: return self.__class__._new(res)
array_api_strict/_array_object.py:1030: return self.__class__._new(res)
array_api_strict/_array_object.py:1042: return self.__class__._new(res)
array_api_strict/_array_object.py:1063: return self.__class__._new(res)
array_api_strict/_array_object.py:1084: return self.__class__._new(res)
array_api_strict/_array_object.py:1105: return self.__class__._new(res)
array_api_strict/_array_object.py:1149: return self.__class__._new(res)
array_api_strict/_array_object.py:1170: return self.__class__._new(res)
array_api_strict/_array_object.py:1191: return self.__class__._new(res)
array_api_strict/_array_object.py:1212: return self.__class__._new(res)
array_api_strict/_array_object.py:1282: return self.__class__._new(self._array.T)
array_api_strict/_creation_functions.py:82: return Array._new(new_array)
array_api_strict/_creation_functions.py:214: return Array._new(np.from_dlpack(x))
array_api_strict/_data_type_functions.py:57: Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])
array_api_strict/_data_type_functions.py:69: return Array._new(np.broadcast_to(x._array, shape))
array_api_strict/_linalg.py:62: U = Array._new(L).mT
array_api_strict/_linalg.py:66: return Array._new(L)
array_api_strict/_linalg.py:94: return Array._new(np.cross(x1._array, x2._array, axis=axis))
array_api_strict/_linalg.py:107: return Array._new(np.linalg.det(x._array))
array_api_strict/_linalg.py:119: return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
array_api_strict/_linalg.py:150: return Array._new(np.linalg.eigvalsh(x._array))
array_api_strict/_linalg.py:164: return Array._new(np.linalg.inv(x._array))
array_api_strict/_linalg.py:184: return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord))
array_api_strict/_linalg.py:200: return Array._new(np.linalg.matrix_power(x._array, n))
array_api_strict/_linalg.py:223: return Array._new(np.count_nonzero(S > tol, axis=-1))
array_api_strict/_linalg.py:243: return Array._new(np.outer(x1._array, x2._array))
array_api_strict/_linalg.py:262: return Array._new(np.linalg.pinv(x._array, rcond=rtol))
array_api_strict/_linalg.py:351: return Array._new(_solve(x1._array, x2._array))
array_api_strict/_linalg.py:375: return Array._new(np.linalg.svd(x._array, compute_uv=False))
array_api_strict/_linalg.py:400: return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)))
array_api_strict/_linalg.py:440: res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))
array_api_strict/_searching_functions.py:46: return tuple(Array._new(i) for i in np.nonzero(x._array))
So it would be worth double checking all of those.
As you mentioned, we should indeed be testing most of this in the test suite. However, I'm not really sure how soon that will happen. There's quite a backlog of things to do in the test suite right now, and my current priority is implementing tests for new functions added in 2023.12 or 2024.12 versions of the standard. So some very simple tests here would not hurt. The tests already have a list of two-argument functions which could be reused.
Were you wanting to add support for devices that don't support certain dtypes, or is that something that we should add in a later pull request?
Here are all the places in the code that call _new without a device keyword
Actually, I think we should make device a required argument to Array._new. That way all APIs in array-api-strict are required to make sure they do proper device handling.
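Concretely, the idea is something like the following sketch (not the actual array-api-strict code, which also does dtype validation and other checks):

```python
import numpy as np

class Array:
    @classmethod
    def _new(cls, x, /, device):
        # No default for device: every internal call site must pass it
        # explicitly, so forgetting to propagate it becomes a TypeError.
        obj = super().__new__(cls)
        obj._array = np.asarray(x)
        obj._device = device
        return obj

    @property
    def device(self):
        return self._device
```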
Were you wanting to add support for devices that don't support certain dtypes, or is that something that we should add in a later pull request?
I'd do that in a separate PR, if only because this one is already quite long and hard to check by looking at the diff.
Here are all the places in the code that call _new without a device keyword

Actually, I think we should make device a required argument to Array._new. That way all APIs in array-api-strict are required to make sure they do proper device handling.
That is a good idea. I like it.
Nice. I feel much better about this after the latest commit making device required.
It looks like another PR I just merged has created a small conflict here, but other than that, I am +1 to merging this.
Thanks for taking this over the finish line!
I opened https://github.com/data-apis/array-api-strict/issues/70 for ideas for further work here.
Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled. Closes #56
With scikit-learn we run into the frustrating situation where contributors run tests locally and they all pass, but then they see failures on the CI related to the fact that e.g. PyTorch has several devices and some things work on the CPU device but not on the CUDA/MPS device. However, if you have neither of those on your local machine you can't really test this upfront, and to debug it you have to rely on the CI.
The idea of this PR is to add support for multiple devices to array-api-strict to make testing easier. The default device continues to be the CPU device, and for arrays that use it nothing should change. However, you can now place an array on a different device with array_api_strict.Device("pony") (or some other string; each string is a new device). For arrays on a device that isn't the CPU device, calls like np.asarray(some_strict_array) will raise an error. This mirrors how PyTorch treats arrays on the CPU and MPS devices.

What isn't yet implemented in this PR is raising an error if you try to operate on arrays that are not on the same device.
I wanted to open this PR now, after just a short amount of effort, to get feedback on what people think about this before putting in the time to update all the tests, etc.
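To illustrate the intended behaviour, a sketch only: note that later in this thread the arbitrary-string idea is replaced by a small fixed set of device names, so with the merged code you would use e.g. Device("device1") rather than Device("pony"), and the exact exception type is whatever the PR settles on.

```python
import numpy as np
import array_api_strict as xp

x = xp.asarray([1.0, 2.0], device=xp.Device("device1"))
print(x.device)

# Converting an array that is not on the CPU device back to NumPy is
# expected to fail, mirroring PyTorch's behaviour for CUDA/MPS tensors.
try:
    np.asarray(x)
except Exception as exc:  # exact exception type is up to this PR
    print(f"refused: {type(exc).__name__}")
```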