google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.23k stars 246 forks

TPU training colab is not working for me #224

Open ViktorM opened 2 years ago

ViktorM commented 2 years ago

The TPU training colab stopped working for me after one of the updates in recent months. It stalls at the training cell: the cell shows that some work is going on, but it never finishes, and no plots or trained policies are produced.

erwincoumans commented 1 year ago

Hi, tried this TPU training colab today, and it fails with this error:


AttributeError                            Traceback (most recent call last)
<ipython-input-3-9ce5fdb19302> in <module>
     14 
     15 try:
---> 16   import brax
     17 except ImportError:
     18   get_ipython().system('pip install git+https://github.com/google/brax.git@main')

2 frames
/usr/local/lib/python3.9/dist-packages/brax/jumpy.py in <module>
    504 
    505 
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
    507           y: jax.typing.ArrayLike) -> ndarray:
    508   """Return elements chosen from `x` or `y` depending on `condition`."""

AttributeError: module 'jax' has no attribute 'typing'

btaba commented 1 year ago

Hi @erwincoumans! I'm not able to reproduce the issue, which jax version are you using?
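
For anyone answering the version question from inside a notebook, here is a minimal sketch that reads the installed versions without importing the packages (so it still works when `import brax` itself fails). The helper name `installed_versions` is made up for this example; it only uses the standard-library `importlib.metadata`.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "brax")):
    """Return {package name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            # Reads the installed distribution's metadata; does not import it.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the printed lines into the issue, together with the runtime type (TPU or GPU), should be enough to compare environments.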

erwincoumans commented 1 year ago

I'm just running the public colab, following the links on the GitHub front page. Did you try training it using a public TPU runtime? The other public colab (training using PyTorch) is also still broken; see the other issue.

https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb

Just tried it again, here is the output:


---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-1-9ce5fdb19302> in <module>
     15 try:
---> 16   import brax
     17 except ImportError:

ModuleNotFoundError: No module named 'brax'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
3 frames
/usr/local/lib/python3.9/dist-packages/brax/jumpy.py in <module>
    504 
    505 
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
    507           y: jax.typing.ArrayLike) -> ndarray:
    508   """Return elements chosen from `x` or `y` depending on `condition`."""

AttributeError: module 'jax' has no attribute 'typing'

btaba commented 1 year ago

Ok, thanks for the pointer! It turns out that jax>=0.4.6 is incompatible with public colab TPU runtimes (see https://stackoverflow.com/a/75734517). We're pinning the jax/jaxlib versions to >=0.4.6 now, so it's best to run in a GPU runtime for the time being, until the colab issue is fixed. I've confirmed that training works on GPU in a public colab runtime.
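
As a rough way to tell whether a given runtime is affected, the sketch below compares a jax version string against the 0.4.6 minimum quoted above and checks for the `jax.typing` attribute that the traceback complains about. The helper names (`parse_version`, `meets_minimum`, `check_jax`) are invented for this example and are not part of brax or jax; the version parsing is deliberately simple and ignores pre-release ordering.

```python
import importlib
import re

MIN_JAX = (0, 4, 6)  # minimum quoted in the comment above


def parse_version(text):
    """Turn '0.4.6' or '0.4.6.dev20230301' into a tuple of its leading integers."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. 'dev20230301'
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(version_text, minimum=MIN_JAX):
    """True if the parsed version is at least `minimum`."""
    return parse_version(version_text) >= minimum


def check_jax():
    """Describe the local jax install, or report that jax is missing."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return "jax is not installed"
    version = getattr(jax, "__version__", "")
    return (
        f"jax {version}: meets minimum={meets_minimum(version)}, "
        f"has jax.typing={hasattr(jax, 'typing')}"
    )


if __name__ == "__main__":
    print(check_jax())
```

If the report shows an older jax without `jax.typing` on a TPU runtime, that matches the failure in this thread, and switching to a GPU runtime is the workaround suggested above.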