mabilton opened 9 months ago
Hi Matt,
That's right, the numpy functionality will not work properly if jax is installed. This is something we are aware of (the test suite only covers the other cases), but we are still trying to decide whether this should be considered the correct behavior. Whatever we decide, it should be documented more prominently.
A fair amount of the functionality will not work with numpy (e.g. autograd-enabled transposes, or the vmapped linear operator for unary function application), so if jax is properly installed, we figured that the user would want that functionality enabled even if they were only using numpy arrays to construct their operators. We are still thinking about how we want to position the limited numpy support.
Do you have a sense for some important use cases where one would want to use numpy even when jax is installed?
Hey @mfinzi.
I certainly agree that for the majority of users, it makes sense to 'automatically switch' from the numpy backend to the jax backend when possible. That said, would there be any harm in allowing users to explicitly disable this 'automatic switching', particularly if doing so would allow us to write tests that run more consistently?
As a (somewhat contrived) example of where this 'automatic backend switching' might cause trouble, suppose another library used cola to perform some linear algebra operations, but later in its code also performed 'jax-unfriendly' computations (e.g. computations involving dynamic array shapes, or in-place array updates). If the maintainers of this library don't have jax installed, everything will appear fine when they use numpy arrays. If a user of this library happens to have jax installed, however, the later 'jax-unfriendly' steps would likely throw an error (unless the maintainers explicitly converted everything back to np.arrays after the steps involving cola, which might be a bit tedious). As I say, it's a bit of a silly example, but I think the general point about 'silently' converting numpy arrays to jax arrays still stands.
Just a couple of other quick thoughts:
- For get_library_fns, I think it would be useful to add a brief comment explaining that the jax backend will be returned if a numpy dtype is passed but jax is installed, since I don't think this is obvious from the function definition.
- … numpy tests failing if jax is installed), it might be worth checking whether something like pytest.mark.skipif can be used to automatically 'turn off' those tests that are expected to fail.

Thanks for your help on this.
Cheers, Matt.
🐛 Bug
Some of the numpy backend unit tests fail if jax is installed, but pass when jax is not installed (i.e. these tests are 'flaky').
To reproduce
Test results when jax is not installed:
Test results when jax is installed:
Note that similar results are observed when pytest -m "numpy" -k "test_get_lu_from_tridiagonal" is run.
Expected Behavior
Ideally, unit tests should run in a predictable and consistent manner, with the result of a given test not depending on which optional dependencies the user may or may not have installed on their machine.
System information
jax version: 0.4.16
Additional context
I encountered this issue when running the test suite for the first time, before starting work on #75. It appears that the current CI workflow doesn't 'pick up' on this problem because the numpy tests are only executed when jax is not installed.
From my own experiments, it seems that the source of the flakiness in these numpy tests is that cola.backends.get_library_fns correctly infers the back-end of a numpy array to be numpy_fns when jax is not installed, but incorrectly infers the back-end to be jax_fns when jax is installed. We can see why this occurs by considering the current implementation of get_library_fns:
That is, get_library_fns will infer the back-end to be jax if jax can be imported and if dtype matches a jax.numpy type. Unfortunately, it turns out (much to my surprise) that jax.numpy types are basically just aliases for numpy types, which means that Python evaluates jax.numpy and numpy types as equal to one another:
This means get_library_fns will always return jax_fns when given a numpy array if jax is installed. Even more surprisingly, the dtype property of a jax.numpy array is not even guaranteed to be a jax.numpy type:
I think these observations illustrate that the 'premise' behind the get_library_fns function (i.e. that you can determine which back-end to use purely from the dtype property of an array) probably isn't sound.
Proposed Solutions
Two potential fixes come to mind:
1. Replace the get_library_fns function with a similar function that requires the user to explicitly name the back-end they wish to be returned.
2. Add a flag to get_library_fns to 'force' it to return the numpy backend, even when jax is installed; this flag can then be used during the numpy tests to ensure that they're consistent.

I'm more than happy to work on this issue, but it would be great to hear what others think about all this first. Thanks in advance for any help.
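For illustration, here is a minimal, stdlib-only sketch of the two fixes above. It assumes nothing about cola's real internals: infer_backend, get_backend, force_numpy, jax_installed and SHARED_DTYPES are all hypothetical names, and plain strings stand in for the numpy_fns/jax_fns namespaces. Because jax.numpy scalar types compare equal to their numpy counterparts, the sketch models both libraries as sharing one set of dtype names, which makes it easy to see that a dtype-only dispatcher can only ever distinguish the backends by whether jax is installed.

```python
import importlib.util
from typing import Optional

# Hypothetical model: jax.numpy scalar types compare equal to their numpy
# counterparts, so a dtype carries no information about which library created
# the array. Both backends therefore share one set of dtype names here.
SHARED_DTYPES = frozenset({"float32", "float64", "int32", "int64"})


def jax_is_installed() -> bool:
    """Return True if jax can be imported, without actually importing it."""
    return importlib.util.find_spec("jax") is not None


def infer_backend(dtype: str, *, force_numpy: bool = False,
                  jax_installed: Optional[bool] = None) -> str:
    """Hypothetical dtype-based dispatch with an opt-out flag (fix 2).

    Unless force_numpy is set, the answer depends only on whether jax is
    installed, never on where the array actually came from.
    """
    if dtype not in SHARED_DTYPES:
        raise TypeError(f"unsupported dtype: {dtype!r}")
    if jax_installed is None:
        jax_installed = jax_is_installed()
    if jax_installed and not force_numpy:
        return "jax"
    return "numpy"


def get_backend(name: str, *, jax_installed: Optional[bool] = None) -> str:
    """Hypothetical explicit selection (fix 1): the caller names the backend."""
    if name not in ("numpy", "jax"):
        raise ValueError(f"unknown backend: {name!r}")
    if jax_installed is None:
        jax_installed = jax_is_installed()
    if name == "jax" and not jax_installed:
        raise ImportError("the jax backend was requested but jax is not installed")
    return name


# The same float32 input yields different backends in different environments:
print(infer_backend("float32", jax_installed=False))                   # numpy
print(infer_backend("float32", jax_installed=True))                    # jax
print(infer_backend("float32", jax_installed=True, force_numpy=True))  # numpy
print(get_backend("numpy", jax_installed=True))                        # numpy
```

The explicit-name version makes the choice visible at every call site, while the flag version keeps automatic switching as the default but gives the numpy tests a way to pin the backend regardless of what is installed.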
Cheers, Matt.