google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

very very slow on local computer (even with GPU) #491

Closed ucacaxm closed 2 weeks ago

ucacaxm commented 2 weeks ago

Hi all,

I tested the tutorial on my local computer (Ubuntu 24.04, RTX 3060, driver 535.171.04): https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb

On Colab, it takes 10 min. On my computer, I stopped it after 10 h without any result. With nvtop, I see that GPU memory usage is around 10 GB but GPU utilization is only around 2%.

I added

  from jax.lib import xla_bridge
  print(xla_bridge.get_backend().platform)

It prints "gpu", so JAX should be using the GPU.

I left the parameters like this

train_fn = functools.partial(
    ppo.train, num_timesteps=30_000_000, num_evals=5, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048,
    batch_size=1024, seed=0)

I get this warning (the same one as in https://github.com/google/brax/issues/488):

    2024-06-13 18:33:17.198373: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.2
    which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling
    parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the
    NVIDIA-provided CUDA forward compatibility packages.

But that only affects compilation, right? After that, execution should be fast?

Any idea where the problem comes from?
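One way to check this is to time the first call of a jitted function (which includes tracing and XLA compilation) against a second call (which reuses the compiled code). The snippet below is only a minimal sketch: `step` is an arbitrary stand-in workload and `time_call` is a hypothetical helper, neither of which comes from the tutorial. It uses `jax.default_backend()` and `jax.devices()` to report which device JAX picked.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args):
    """Return wall-clock seconds for one call, waiting for the device to finish."""
    start = time.perf_counter()
    # JAX dispatches work asynchronously, so block before stopping the clock.
    jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


@jax.jit
def step(x):
    # Arbitrary stand-in workload, not the tutorial's training step.
    return jnp.tanh(x @ x).sum()


x = jnp.ones((512, 512))

print("backend:", jax.default_backend())
print("devices:", jax.devices())

first = time_call(step, x)   # includes tracing and XLA compilation
second = time_call(step, x)  # reuses the already compiled executable
print(f"first call: {first:.3f} s, second call: {second:.3f} s")
```

If the second call is also slow, or the script never gets that far, the compilation warning is probably not the cause and something else is holding up the run.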

Thanks,

PS: My install is: RTX 3060, driver 535.171.04

    brax                       0.10.5
    jax                        0.4.29
    jax-cuda12-pjrt            0.4.29
    jax-cuda12-plugin          0.4.29
    jaxlib                     0.4.29
    jaxopt                     0.8.3
    mujoco                     3.1.6
    mujoco-mjx                 3.1.6
    nvidia-cublas-cu12         12.5.2.13
    nvidia-cuda-cupti-cu12     12.5.39
    nvidia-cuda-nvcc-cu12      12.5.40
    nvidia-cuda-runtime-cu12   12.5.39
    nvidia-cudnn-cu12          9.1.1.17
    nvidia-cufft-cu12          11.2.3.18
    nvidia-cusolver-cu12       11.6.2.40
    nvidia-cusparse-cu12       12.4.1.24
    nvidia-nccl-cu12           2.21.5
    nvidia-nvjitlink-cu12      12.5.40

ucacaxm commented 2 weeks ago

Sorry, I found my answer. Execution was blocked at the plot (I was a bit stupid).

ucacaxm commented 2 weeks ago

GPU usage is now high and execution is fast.
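The plotting code that blocked is not shown in this thread, so the following is only a sketch of one way to avoid that kind of stall when the tutorial runs as a local script, assuming the block came from an interactive plot window waiting to be closed. The callback records metrics and prints them, and the plot is written to a file with a non-interactive Matplotlib backend. The names `history`, `progress`, and `save_plot` and the output filename are hypothetical. The `(num_steps, metrics)` callback shape and the `eval/episode_reward` key follow the tutorial.

```python
import time

# Hypothetical container for the training curve, filled in by the callback.
history = {"steps": [], "reward": [], "elapsed_s": []}
_start = time.monotonic()


def progress(num_steps, metrics):
    """Record evaluation metrics and print them; never opens a plot window."""
    history["steps"].append(int(num_steps))
    history["reward"].append(float(metrics["eval/episode_reward"]))
    history["elapsed_s"].append(time.monotonic() - _start)
    print(f"steps={num_steps} reward={history['reward'][-1]:.1f}", flush=True)


def save_plot(path="training_progress.png"):
    """Write the recorded reward curve to a file instead of calling plt.show()."""
    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend: renders to files only
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot(history["steps"], history["reward"], marker="o")
    ax.set_xlabel("# environment steps")
    ax.set_ylabel("reward per episode")
    fig.savefig(path)
    plt.close(fig)
```

With this in place, `progress` can be passed as `progress_fn` when calling `train_fn`, and `save_plot()` can be called once training has finished (or from inside the callback, if an updated image after every evaluation is wanted).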