google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond
Apache License 2.0
3.3k stars, 432 forks

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray' #1089

Open Aki1991 opened 2 months ago

Aki1991 commented 2 months ago

Hi all, I am trying to fine-tune our model using the owl_vit model.

But when I try to run it, I get this error: AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'. The JAX version I am using is 0.4.30. With jax 0.4.23 it works, but training then does not use the GPU, which slows it down a lot. Is there a way to keep using jax 0.4.30 and resolve this error?
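A quick way to check whether a given JAX install can actually see the GPU is sketched below. It only uses the public calls jax.devices() and jax.default_backend(); the helper name describe_backend is hypothetical, and the exact backend string reported (for example "gpu" versus "cpu") depends on which jaxlib build is installed.

```python
import jax


def describe_backend() -> dict:
    """Report the JAX version, the selected backend, and the visible devices."""
    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        # Platform name of the default backend, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    info = describe_backend()
    print(info)
    if info["backend"] == "cpu":
        print("JAX is running on CPU only; check that a CUDA-enabled jaxlib is installed.")
```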

If I replace PRNGKeyArray with key, I get the following error at a later stage:

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 61, in <module>
    app.run(main=main)
  File "/home/user/Akash/Owl/scenic/scenic/app.py", line 68, in run
    app.run(functools.partial(_run_main, main=main))
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/Akash/Owl/scenic/scenic/app.py", line 109, in _run_main
    main(rng=rng, config=config, workdir=workdir, writer=writer)
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 51, in main
    trainer.train(
  File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/trainer.py", line 218, in train
    gflops) = train_utils.initialize_model(
  File "/home/user/Akash/Owl/scenic/scenic/train_lib/train_utils.py", line 187, in initialize_model
    flops = debug_utils.compute_flops(
  File "/home/user/Akash/Owl/scenic/scenic/common_lib/debug_utils.py", line 139, in compute_flops
    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

Can anyone suggest what I can do here? Thank you.

UPDATE: I installed all libraries with compatible versions and got training to run on the GPU with jax==0.4.23, but I am still getting the error mentioned above:

    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

dr4thmos commented 2 months ago

Same issue here.

Aki1991 commented 2 months ago

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray' can be solved by replacing jax.random.PRNGKeyArray with jax.Array.
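A minimal compatibility sketch of that replacement, assuming PRNGKeyArray is only used as a type annotation. It keeps the old name when the installed JAX still provides it and otherwise falls back to jax.Array. The alias KeyArray and the helper init_dropout_mask are hypothetical names for illustration, not part of Scenic.

```python
import jax
import jax.numpy as jnp

# Older JAX releases expose jax.random.PRNGKeyArray; newer ones removed it.
# Fall back to jax.Array, the replacement suggested in this thread.
KeyArray = (
    jax.random.PRNGKeyArray if hasattr(jax.random, "PRNGKeyArray") else jax.Array
)


def init_dropout_mask(rng: KeyArray, shape, rate: float = 0.1) -> jax.Array:
    """Hypothetical helper: a function whose key argument uses the compat alias."""
    keep = jax.random.bernoulli(rng, 1.0 - rate, shape)  # boolean keep-mask
    return keep.astype(jnp.float32)


rng = jax.random.PRNGKey(0)
mask = init_dropout_mask(rng, (2, 3))
```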

But that does not solve this error:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
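The traceback shows that analysis is None by the time compute_flops indexes it, meaning the XLA cost analysis returned nothing. As a hedged local workaround (not a fix from the Scenic maintainers), the lookup could be guarded with a helper like the hypothetical extract_flops below. It assumes the cost analysis may come back as a dict, a list of dicts, or None depending on the JAX version and backend.

```python
from collections.abc import Mapping
from typing import Any, Optional


def extract_flops(analysis: Any) -> Optional[float]:
    """Return the 'flops' entry of an XLA cost analysis, or None if unavailable."""
    if analysis is None:
        # No cost analysis was produced for this backend/version.
        return None
    if isinstance(analysis, (list, tuple)):
        # Some JAX versions return one dict per compiled computation.
        if not analysis:
            return None
        analysis = analysis[0]
    if not isinstance(analysis, Mapping):
        return None
    return analysis.get("flops")
```

If compute_flops were patched to return extract_flops(analysis) instead of analysis['flops'], any caller that does arithmetic on the result (for example converting it to GFLOPs) would also need to handle None.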

LihanWa commented 2 months ago

This should be fixed in ott-jax==0.3.1.

Aki1991 commented 2 months ago

I am already using ott-jax==0.3.1 and still get the same error.

LihanWa commented 2 months ago

Sorry, you should first run "pip install ott-jax==0.4.5"; if you then get an error about "transport", run "pip install ott-jax==0.3.1".

Aki1991 commented 2 months ago

Yes, I am getting the "transport" error; that's why I am using ott-jax==0.3.1. And that leads to this error:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

thecho7 commented 2 months ago

pip install ott-jax==0.2.0 works.
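Since the thread ends with several different version pins, a small sketch that reports what is actually installed may help when comparing setups. It uses only importlib.metadata from the standard library; check_pins is a hypothetical helper, and the pins shown are simply the versions mentioned in this thread, not a combination verified to work together.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins: Dict[str, str]) -> Dict[str, Tuple[Optional[str], bool]]:
    """Map each distribution name to (installed_version, matches_pin)."""
    results = {}
    for name, wanted in pins.items():
        got = installed_version(name)
        results[name] = (got, got == wanted)
    return results


if __name__ == "__main__":
    # Versions mentioned in this thread; not a verified working combination.
    pins = {"jax": "0.4.23", "ott-jax": "0.2.0"}
    for name, (got, ok) in check_pins(pins).items():
        print(f"{name}: installed={got}, wanted={pins[name]}, match={ok}")
```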