Aki1991 opened this issue 2 months ago.
Same issue here.
AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
This can be solved by replacing jax.random.PRNGKeyArray with jax.Array.
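A minimal sketch of that replacement, assuming a newer JAX (such as 0.4.30) where jax.random.PRNGKeyArray no longer exists. In code you control, annotate PRNG keys with jax.Array. For third-party code you cannot edit, one common workaround (an assumption, not an official JAX fix) is to alias the old name back onto jax.random before that code is imported. The function sample_noise is hypothetical and only illustrates the annotation change.

```python
import jax

# Workaround (assumption, not an official fix): restore the removed alias so
# third-party code that still references jax.random.PRNGKeyArray can import.
# This must run BEFORE importing that code.
if not hasattr(jax.random, "PRNGKeyArray"):
    jax.random.PRNGKeyArray = jax.Array


# Preferred fix in your own code: type PRNG keys as jax.Array.
# `sample_noise` is a hypothetical example function.
def sample_noise(rng: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Draw standard-normal noise; `rng` was previously typed jax.random.PRNGKeyArray."""
    return jax.random.normal(rng, shape)


rng = jax.random.PRNGKey(0)
rng, noise_rng = jax.random.split(rng)  # split the key before consuming it
noise = sample_noise(noise_rng, (2, 3))
print(noise.shape)  # (2, 3)
```

The alias only affects the name lookup; it does not change how keys behave, so it will not help with errors that appear later for other reasons.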
But that does not solve this error:
flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
This should be fixed in ott-jax==0.3.1
I am using that same version, ott-jax==0.3.1, and I still get the same error.
Sorry, you should first run "pip install ott-jax==0.4.5"; if you then get an error about "transport", run "pip install ott-jax==0.3.1".
Yes, I am getting the "transport" error; that's why I am using ott-jax==0.3.1. That version leads to this error:
flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
Running pip install ott-jax==0.2.0 works.
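The TypeError above means the FLOP analysis came back as None instead of a dict, so indexing it with 'flops' fails. A defensive sketch, assuming the analysis comes from an XLA cost analysis such as jax.jit(fn).lower(*args).compile().cost_analysis(), whose result may be None, a dict, or a list of dicts depending on the JAX version and backend. The helper extract_flops is hypothetical; it lets training continue without a FLOP count instead of crashing.

```python
from collections.abc import Mapping
from typing import Any, Optional


def extract_flops(analysis: Any) -> Optional[float]:
    """Return the 'flops' entry of a cost-analysis result, or None if unavailable.

    Hypothetical helper: handles None, a single mapping, or a list of mappings
    (in which case only the first entry is inspected).
    """
    if analysis is None:
        return None
    if isinstance(analysis, (list, tuple)):
        if not analysis:
            return None
        analysis = analysis[0]
    if not isinstance(analysis, Mapping):
        return None
    flops = analysis.get("flops")
    return None if flops is None else float(flops)


# Usage: replaces a bare `flops = analysis['flops']`.
for analysis in (None, [{"flops": 1.5e9}], {"flops": 42}):
    flops = extract_flops(analysis)
    if flops is None:
        print("FLOP count unavailable; skipping")
    else:
        print(flops)
```

This only avoids the crash; it does not restore the FLOP count on a backend that does not report one.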
Hi all, I am trying to fine-tune our model using the owl_vit model.
But when I try to run it, I get this error:
AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
The JAX version I am using is 0.4.30. With jax 0.4.23 it works, but training then does not use the GPU, which slows it down a lot. Is there a way to use jax 0.4.30 and solve this error? If I replace PRNGKeyArray with key, I get an error at a later stage:
Can anyone suggest what I can do here? Thank you.
UPDATE: I installed all libraries with the proper versions and got training running on the GPU with jax==0.4.23, but I am still getting the error mentioned above.
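Since the fixes in this thread depend on exact package versions and on whether JAX actually sees the GPU, a quick diagnostic can help when reporting the problem. A minimal sketch using only importlib.metadata and jax.default_backend(); the helper names and the package list are assumptions chosen for illustration.

```python
from importlib import metadata
from typing import Optional


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_backend() -> Optional[str]:
    """Return JAX's default backend name (e.g. 'cpu' or 'gpu'), or None if JAX is absent."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    print(installed_versions(["jax", "jaxlib", "ott-jax"]))
    # 'cpu' here means JAX fell back to the CPU, which would explain slow training.
    print("JAX backend:", jax_backend())
```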