google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

nvidia GPU acceleration question #344

Open HowdyMoto opened 1 year ago

HowdyMoto commented 1 year ago

I have a machine set up with the following:

- NVIDIA RTX 4090 GPU
- Fresh install of Ubuntu 22.04
- NVIDIA 530 drivers with CUDA 12.1

I set up Brax with `pip install -e .`. After doing so, `learn` works fine on the CPU.

I then set up GPU-accelerated JAX with:

```
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
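(A quick sanity check, not part of the original report but using only standard JAX calls, to confirm the CUDA-enabled jaxlib was actually picked up after this install:)

```python
import jax

# Should print "gpu" when the CUDA-enabled jaxlib is active, "cpu" otherwise.
print(jax.default_backend())

# Should list a CUDA device (e.g. [cuda(id=0)]) rather than a CPU device.
print(jax.devices())
```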

When I run `learn` with no flags, I see the output below. It probes for the tpu and rocm backends and lists CUDA as an available platform, but it never explicitly says it's running on CUDA. If I let it run, it's painfully slow: it will run for many hours with few updates. `learn --helpfull` doesn't show any flags I obviously need to set to use CUDA and cuDNN. Am I missing any steps to get it working? I know this should be extremely fast on this GPU.

```
I0423 20:06:11.578360 140285023629312 metrics.py:42] Hyperparameters: {'num_evals': 10, 'num_envs': 4, 'total_env_steps': 50000000}
I0423 20:06:11.665900 140285023629312 xla_bridge.py:440] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
I0423 20:06:11.666139 140285023629312 xla_bridge.py:440] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0423 20:06:11.666195 140285023629312 xla_bridge.py:440] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0423 20:06:13.768080 140285023629312 train.py:107] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I0423 20:06:41.432691 140285023629312 train.py:320] {'eval/walltime': 22.236454486846924, 'eval/episode_distance_from_origin': Array(48.86863, dtype=float32), 'eval/episode_forward_reward': Array(0.9787996, dtype=float32), 'eval/episode_reward': Array(-8.558953, dtype=float32), 'eval/episode_reward_contact': Array(0., dtype=float32), 'eval/episode_reward_ctrl': Array(-49.092438, dtype=float32), 'eval/episode_reward_forward': Array(0.9787996, dtype=float32), 'eval/episode_reward_survive': Array(39.554688, dtype=float32), 'eval/episode_x_position': Array(0.007761, dtype=float32), 'eval/episode_x_velocity': Array(0.9787996, dtype=float32), 'eval/episode_y_position': Array(-5.1503973, dtype=float32), 'eval/episode_y_velocity': Array(-1.67458, dtype=float32), 'eval/avg_episode_length': Array(39.554688, dtype=float32), 'eval/epoch_eval_time': 22.236454486846924, 'eval/sps': 5756.3133581261}
I0423 20:06:41.434379 140285023629312 metrics.py:51] [0] eval/avg_episode_length=39.5546875, eval/episode_distance_from_origin=48.868629455566406, eval/episode_forward_reward=0.97879958152771, eval/episode_reward=-8.558953285217285, eval/episode_reward_contact=0.0, eval/episode_reward_ctrl=-49.092437744140625, eval/episode_reward_forward=0.97879958152771, eval/episode_reward_survive=39.5546875, eval/episode_x_position=0.0077610015869140625, eval/episode_x_velocity=0.97879958152771, eval/episode_y_position=-5.150397300720215, eval/episode_y_velocity=-1.6745799779891968, eval/epoch_eval_time=22.236454, eval/sps=5756.313358, eval/walltime=22.236454
I0423 20:06:41.435662 140285023629312 train.py:326] starting iteration 0 27.667602062225342
```
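(Nothing in this log suggests a broken CUDA install; the failed rocm/tpu probes are normal, and the CUDA platform is registered. One way to rule out the GPU itself, a minimal check using only standard JAX calls, with an arbitrary matrix size:)

```python
import time
import jax
import jax.numpy as jnp

# Random square matrix; the size is arbitrary, just large enough to exercise the GPU.
x = jax.random.normal(jax.random.PRNGKey(0), (4096, 4096))

# First call compiles the kernel; the second, timed call measures steady-state speed.
y = (x @ x).block_until_ready()
start = time.perf_counter()
y = (x @ x).block_until_ready()
print(f"{time.perf_counter() - start:.4f}s on {jax.default_backend()}")
```

On an RTX 4090 this matmul should take on the order of milliseconds; if it does, the slowness lies in the training configuration rather than the CUDA setup.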

erikfrey commented 1 year ago

Hello @HowdyMoto - you'll want to modify the hparams to suit your machine. In your case, `num_envs` should be much higher: 2048 or 4096. That low value is probably the cause of your slowness. Check the training colab for hparams that work well on an accelerator:

https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training.ipynb
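(For reference, a minimal sketch of what a higher `num_envs` looks like with Brax's Python training API; the specific hyperparameter values below are illustrative, not the tuned per-environment settings from the colab:)

```python
import functools

from brax import envs
from brax.training.agents.ppo import train as ppo

# Build an environment; 'ant' matches the eval metrics in the log above.
env = envs.get_environment(env_name='ant', backend='positional')

# num_envs=4096 batches thousands of parallel rollouts onto the GPU,
# which is where Brax gets its speed; num_envs=4 leaves it mostly idle.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000,
    num_evals=10,
    episode_length=1000,
    num_envs=4096,
    batch_size=2048,
    seed=0,
)

make_inference_fn, params, metrics = train_fn(environment=env)
```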