zwei-beiner closed this issue 2 months ago.
You're right! My tests were missing this case, and it is indeed broken.
I've fixed it now (and expanded the tests). Could you reinstall the package (version 0.4.11) and let me know whether it works for you?
The new JAX release seems to work; the problem was deeper, but it should now be fixed.
Thanks a lot, the code works now. Again, thanks for the quick response and for making this package!
Awesome!
Hi, I came across the following bugs:
When I install the current up-to-date version of torch2jax, I get the following error:
This may come from the multi-GPU update. I'm running on a machine with no GPU, only a CPU, so perhaps the multi-GPU code doesn't handle the CPU-only case?
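To confirm which backend JAX actually picks on a CPU-only machine, a quick diagnostic sketch could look like this. It is not part of torch2jax; the helper name `jax_backend_summary` is hypothetical. It only uses `jax.default_backend()` and `jax.devices()`, and returns `None` if JAX is not installed.

```python
import importlib.util


def jax_backend_summary():
    """Return (default backend, list of device platforms), or None if JAX is missing.

    Hypothetical diagnostic helper; on a CPU-only machine one would expect
    the backend to be "cpu" and every device platform to be "cpu".
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the helper works without JAX installed

    return jax.default_backend(), [d.platform for d in jax.devices()]


if __name__ == "__main__":
    print(jax_backend_summary())
```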
So I installed the following commit (from before the multi-GPU update):
With it, the first bug disappears (as expected), but I now get a different error:
Maybe this is related to the latest breaking change in JAX 0.4.32 (see the point under "Breaking changes" here: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-32).
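To check whether an environment is on JAX 0.4.32 or later (the release whose breaking change is linked above), a minimal standard-library sketch could compare release numbers. The helper names are hypothetical. The parser keeps only the leading digits of each dot-separated part, so a pre-release like `0.4.32rc1` counts as `0.4.32` and `.dev…` segments are dropped. For full PEP 440 handling, the `packaging` library would be the better choice.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver: str) -> tuple[int, ...]:
    """Turn '0.4.32' into (0, 4, 32); stop at the first non-numeric part."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # e.g. a 'dev20240901' segment
        parts.append(int(digits))
    return tuple(parts)


def at_least(ver: str, minimum: str) -> bool:
    """Compare release numbers element-wise (tuple comparison)."""
    return release_tuple(ver) >= release_tuple(minimum)


def installed_jax_at_least(minimum: str = "0.4.32") -> bool | None:
    """True/False for the installed jax distribution, None if jax is not installed."""
    try:
        return at_least(version("jax"), minimum)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_jax_at_least())
```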
System info: I ran everything in a clean venv with JAX and PyTorch freshly installed, using Python 3.11, and ran the example code from the README file:
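When reporting system info like this, a small standard-library sketch can collect the relevant versions in one go. The helper name and the package list are assumptions for illustration. It reads installed distribution versions via `importlib.metadata` and detects a `venv` by comparing `sys.prefix` with `sys.base_prefix`.

```python
from __future__ import annotations

import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("jax", "jaxlib", "torch", "torch2jax")) -> dict[str, str]:
    """Collect Python version, OS string, venv status, and package versions."""
    report = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        # Inside a venv, sys.prefix points at the venv, base_prefix at the base install.
        "in_venv": str(sys.prefix != sys.base_prefix),
    }
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```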