jakevdp opened this issue 3 hours ago
(To anticipate one response: no, it's not possible to make JAX arrays support mutation. Central to JAX are transformations like jit, vmap, and grad that rely on immutability assumptions in their program tracing.)
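Here is a small illustration of the point, assuming a recent JAX installation. Item assignment on a JAX array raises a TypeError. The supported alternative is the functional x.at[...] syntax, which returns a new array and can be traced under jax.jit. The names add_at, rejected, y, and z are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)

# JAX arrays are immutable: in-place item assignment raises a TypeError.
rejected = False
try:
    x[0] = 1.0
except TypeError:
    rejected = True

# Functional update: returns a new array and leaves x untouched.
y = x.at[0].set(1.0)

@jax.jit
def add_at(arr, i, v):
    # Under jit the update is traced as a pure operation on arr.
    return arr.at[i].add(v)

z = add_at(y, 1, 2.0)
```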
For (3), could you prototype what it would look like in the case of gh-609? For capabilities["mutable arrays"] == True, we use the NumPy syntax x[i] += y. For capabilities["mutable arrays"] == False, we use ...? This would require standardising a way to do this for immutable arrays, right? Or can we just use xp.where?
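One possible shape for such a prototype is sketched below in Python, with NumPy standing in for xp. The function name add_at_mask is hypothetical. The capabilities dict is passed in explicitly, and "mutable arrays" is the key proposed in this issue, not an existing one. The sketch is restricted to a boolean mask and a scalar y, the case where the two branches agree.

```python
import numpy as np

def add_at_mask(xp, x, mask, y, capabilities):
    """Hypothetical helper: add scalar y to the entries of x selected by mask."""
    if capabilities.get("mutable arrays", False):
        # Mutable backends (e.g. NumPy): update the existing buffer in place.
        x[mask] += y
        return x
    # Immutable backends: build a new array with xp.where instead.
    return xp.where(mask, x + y, x)
```

For integer index arrays with repeated entries, NumPy's x[i] += y applies each repeated index only once. A where-based rewrite is therefore not a drop-in replacement in that case, which is where a standardised functional update would likely be needed.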
Several parts of the Array API standard assume that array objects are mutable.
This is very surprising. It would be nice if we could have a list of such occurrences here, because this was not supposed to happen according to our design guidelines: https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html
Several parts of the Array API standard assume that array objects are mutable. Some array API implementations (notably JAX) do not support mutating array objects. As a result, the array API implementations currently being developed in scipy and sklearn are entirely unusable with JAX.

Given this, downstream implementations have a few choices:
…
(3) Special-case jax.numpy.Array, changing the implementation logic for that case.

(1) is a bad choice, because it means JAX will not be supported. (2) is a bad choice, because for libraries like NumPy it leads to excessive copying of buffers, worsening performance. (3) is a bad choice because it hard-codes the presence of specific implementations in a context that is supposed to be implementation-agnostic.
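The sketch below makes the trade-off between (2) and (3) concrete. It assumes (2) means always taking a non-mutating path such as xp.where. It implements (3) with a hypothetical is_jax_array helper that inspects the module defining the array's type. None of these helpers come from scipy, sklearn, or JAX.

```python
import numpy as np

def is_jax_array(x):
    # Hypothetical helper for strategy (3): detect JAX by the module that
    # defines the array's type, so JAX need not be imported. This is the
    # kind of implementation-specific check the issue objects to.
    return type(x).__module__.split(".")[0] in ("jax", "jaxlib")

def scale_masked_copying(xp, x, mask, factor):
    # Strategy (2): never mutate. xp.where allocates a fresh output array on
    # every call, even for NumPy, where an in-place update would suffice.
    return xp.where(mask, x * factor, x)

def scale_masked_special_cased(xp, x, mask, factor):
    # Strategy (3): mutate in place unless the input is recognised as JAX.
    if is_jax_array(x):
        return xp.where(mask, x * factor, x)
    x[mask] *= factor
    return x
```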
One way the Array API standard could address this is by adding "mutable arrays" (or something similar) to the existing capabilities dict. Downstream implementations could then use strategy (3) without special-casing particular implementations.
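The sketch below shows how downstream code might query such a flag. It assumes the capabilities dict is reached through the inspection API, xp.__array_namespace_info__().capabilities(), as in the 2023.12 revision of the standard. The "mutable arrays" key is the hypothetical one proposed here, and supports_mutation is an illustrative name. Missing entries default to False, so namespaces that do not advertise the flag take the non-mutating path. That choice is safe for JAX but slower for NumPy until NumPy reports the key.

```python
from types import SimpleNamespace

import numpy as np

def supports_mutation(xp):
    """Return True only if xp advertises the (hypothetical) "mutable arrays" key."""
    info_fn = getattr(xp, "__array_namespace_info__", None)
    if info_fn is None:
        # Namespaces without the inspection API: assume nothing.
        return False
    return bool(info_fn().capabilities().get("mutable arrays", False))

# A stand-in namespace that advertises the proposed key, for illustration.
_fake_info = SimpleNamespace(capabilities=lambda: {"mutable arrays": True})
fake_xp = SimpleNamespace(__array_namespace_info__=lambda: _fake_info)
```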