phinate opened this issue 1 year ago.
Hello, thanks! I was pretty pleased with it too :)
There's some relevant discussion in https://github.com/google/jax/issues/17107#issuecomment-1686497820. Essentially, we're not officially supporting it outside the GraphCast project for now. It has some rough edges (see the thread for some examples), and it is partly a stop-gap measure until JAX supports the new Python array API standard, which will allow it to integrate better with xarray.
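To make the integration point more concrete: under the Python array API standard, a conforming array exposes an `__array_namespace__()` method that returns a module of standard functions, so a library like xarray can write backend-agnostic code roughly like the sketch below. The helper names `get_namespace` and `normalize` are hypothetical, and falling back to NumPy for objects without the method is an assumption made so the sketch runs anywhere.

```python
import numpy as np


def get_namespace(x):
    """Hypothetical helper: return the array API namespace for x.

    Arrays that follow the Python array API standard expose
    __array_namespace__(); anything else falls back to NumPy
    (an assumption made for this sketch).
    """
    ns = getattr(x, "__array_namespace__", None)
    return ns() if ns is not None else np


def normalize(x):
    # Written once against the namespace, with no backend-specific calls,
    # so the same code could run on any conforming array type.
    xp = get_namespace(x)
    return (x - xp.mean(x)) / xp.std(x)


z = normalize(np.array([1.0, 2.0, 3.0]))
```

The design choice is that the array, not the calling library, decides which functions operate on it; that is what would let xarray-style code run on JAX arrays without a hand-written wrapper for every operation.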
If you wanted to speed things along here, one really valuable thing would be working with the JAX folks to push along support for the array API in JAX. From what I gather, one of the remaining obstacles is figuring out how to handle the mutable parts of the array API, since JAX doesn't support mutation. (One thing I've done in JaxArrayWrapper here is just to drop / ignore the out= kwargs that show up when ufuncs are run in-place, meaning it returns a fresh array and the mutation doesn't happen in-place. However, this isn't perfect and it might be better to throw an error; I'm also not sure whether the same approach would work for the new array API.) IMO it would be OK just to omit support for in-place mutation.
@shoyer may be a good person to talk to about this (or might at least know who to ask), as he has good visibility into both the xarray and JAX projects.
In the long run I'd definitely like to improve this and make a standalone release, although I was waiting for the array API implementation in JAX before investing much more in it as a standalone project, since that could change the required approach quite a bit, simplify it, and remove some (not all) of the rough edges.
If you wanted to maintain your own fork in the meantime, I won't object, and others may find it useful, although I can't promise we'd switch to it or co-maintain it ourselves if that proved disruptive to our current research. By the nature of code with slightly rough edges, changing those edges can end up breaking things downstream in subtle and annoying ways. (Admittedly, we could also use more test coverage for some of these things.) But I'd be interested to follow what you do anyway, and we could maybe join up with it at some later date.
Hope that makes sense anyway and thanks for volunteering!
Just coming back to this now -- thanks so much @mjwilson for your response!
Conforming to the Array API standard definitely makes sense before investing too much effort, so I'm totally with you on that (and on not supporting mutation: it seems like this isn't part of the array API standard anyway!). I had a look around the existing issues, and it seems like there are a couple of roadblocks here and there, but they looked a little stale. I'll look at finding an entry point there; I think it would be great for this to go through (even without the xarray support!).
RE: maintaining a fork -- totally makes sense. We're not 100% concrete on our use case or if we'll have the engineering effort to maintain something, but will definitely be trialling the great work you've done here. If, however, you start to push on making this more mature at any point, I'd definitely be interested in lending a hand :)
Thanks again for your thorough answer, and I hope to talk again on this! I'll leave the issue open for the time being since there are potential directions here, but feel free to close it if you'd rather keep things clean until other developments pop up.
Hello! Now that JAX has support for the Array API (https://github.com/google/jax/issues/18353), I'm wondering if there is any talk of implementing tighter integration between the two libraries?
Hi! I'm a big fan of the boilerplate code used to wrap xarray into a JAX-compatible entity. I think this could have wider usage were it better known, especially for this kind of deep learning + weather project. Would you consider distributing this code as a standalone helper module? I'm happy to volunteer to refactor this code into a small library, since I'm probably going to use it anyway. Let me know if you'd like me to take the initiative on that (and where it should live).