sh0416 closed this 1 month ago.
I installed from source, FYI.
Huh. What jax and jaxlib versions are you on?
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gpusystem', release='5.15.0-100-generic', version='#110-Ubuntu SMP Wed Feb 7 13:27:48 UTC 2024', machine='x86_64')
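For anyone hitting the same issue: the environment dump above looks like the output of `jax.print_environment_info()`, which is what the JAX bug template asks for. Below is a minimal sketch that calls it when jax imports cleanly and otherwise falls back to a stdlib-only report built with `importlib.metadata`. The helper names `package_versions` and `environment_report` are hypothetical, not part of jax.

```python
import platform
import sys
from importlib import metadata


def package_versions(names=("jax", "jaxlib", "numpy")):
    """Return {package: version} for each name, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def environment_report():
    """Build a short plain-text report with package, Python, and OS details."""
    lines = [
        f"{name}: {ver or 'not installed'}"
        for name, ver in package_versions().items()
    ]
    lines.append(f"python: {sys.version}")
    # platform.uname() prints as uname_result(system=..., node=..., ...)
    lines.append(f"platform: {platform.uname()}")
    return "\n".join(lines)


if __name__ == "__main__":
    try:
        import jax

        # JAX's built-in report; also lists devices and process_count.
        jax.print_environment_info()
    except ImportError:
        # jax (or jaxlib) is missing or broken, so report what we can.
        print(environment_report())
```

The fallback matters when the bug is in the install itself (e.g. a source build where `import jax` fails), since `jax.print_environment_info()` can only run once jax imports.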
This one... it seems the problem is in jax, not your code.
Okie, I'll do some testing; it should be easy to fix.
Tested this on 0.4.28; it should work now.
I got this error while testing with pytest. Is there a simple way to resolve it?