google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0
7.85k stars 786 forks

Issue with GPU allocation (only occurs when running SAC not PPO) #1431

Closed AlexS28 closed 7 months ago

AlexS28 commented 7 months ago

Hi, I am trying to run my environment using MJX. Everything works fine with PPO: I can easily run over 8,000 environments with a batch size over 1,000 and no memory issues. With SAC it is a different story: no matter what settings I try, it fails even with a batch size of 32 and a single environment, and I cannot avoid the memory-allocation error pasted below. From trying other simulators, my particular environment only really trains with SAC, not PPO, so I do need SAC to work. Any help would be greatly appreciated.

My computer has an NVIDIA RTX 4090 GPU (24 GB of VRAM).

Running this training:

```python
train_fn = functools.partial(
    sac.train,
    num_timesteps=30_000_000,
    num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True,
    action_repeat=1,
    num_envs=1,
    batch_size=32,
    seed=0,
)
```

Getting this error:

```
2024-02-19 15:14:56.882326: W external/xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below -30.29GiB (-32529110990 bytes) by rematerialization; only reduced to 49.62GiB (53280000080 bytes), down from 49.62GiB (53280000080 bytes) originally
2024-02-19 15:15:06.899957: W external/tsl/tsl/framework/bfc_allocator.cc:487] Allocator (GPU_0_bfc) ran out of memory trying to allocate 49.62GiB (rounded to 53280000000) requested by op
2024-02-19 15:15:06.900601: W external/tsl/tsl/framework/bfc_allocator.cc:499] ____**_____
E0219 15:15:06.900728 6066 pjrt_stream_executor_client.cc:2766] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 8B
    constant allocation: 8B
    maybe_live_out allocation: 49.62GiB
    preallocated temp allocation: 160B
    total allocation: 49.62GiB
    total fragmentation: 168B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 49.62GiB
        Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
        XLA Label: fusion
        Shape: f32[30000000,444]
        ==========================
    Buffer 2:
        Size: 32B
        XLA Label: tuple
        Shape: (f32[30000000,444], s32[], s32[], u32[2])
        ==========================
    Buffer 3:
        Size: 32B
        XLA Label: tuple
        Shape: (f32[30000000,444], s32[], s32[], u32[2])
        ==========================
    Buffer 4:
        Size: 16B
        XLA Label: fusion
        Shape: (s32[], s32[])
        ==========================
    Buffer 5:
        Size: 8B
        Entry Parameter Subshape: u32[2]
        ==========================
    Buffer 6:
        Size: 8B
        XLA Label: fusion
        Shape: u32[2]
        ==========================
    Buffer 7:
        Size: 4B
        XLA Label: fusion
        Shape: s32[]
        ==========================
    Buffer 8:
        Size: 4B
        XLA Label: fusion
        Shape: s32[]
        ==========================
    Buffer 9:
        Size: 4B
        XLA Label: constant
        Shape: f32[]
        ==========================
    Buffer 10:
        Size: 4B
        XLA Label: constant
        Shape: s32[]
        ==========================
```
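The 49.62 GiB in Buffer 1 follows directly from the replay-buffer shape `f32[30000000,444]` in the log; a quick sanity check of the arithmetic:

```python
# Reproduce the allocation size reported for Buffer 1: one f32 row of
# width 444 per stored transition, with one slot per training timestep
# (Brax's SAC defaults max_replay_size to num_timesteps when unset).
num_timesteps = 30_000_000   # becomes the replay-buffer capacity
transition_width = 444       # flattened transition per row, from the log
bytes_per_f32 = 4

total_bytes = num_timesteps * transition_width * bytes_per_f32
total_gib = total_bytes / 2**30

print(total_bytes)           # 53280000000, the allocation in the log
print(round(total_gib, 2))   # 49.62 GiB, far beyond a 24 GB card
```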

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

```
Traceback (most recent call last):
  File "/home/sclaer/ASCENT_biped/sim/mujoco/rl_mjx.py", line 367
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/agents/sac/train.py", line 429, in train
    buffer_state = jax.pmap(replay_buffer.init)(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 8B
    constant allocation: 8B
    maybe_live_out allocation: 49.62GiB
    preallocated temp allocation: 160B
    total allocation: 49.62GiB
    total fragmentation: 168B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 49.62GiB
        Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
        XLA Label: fusion
        Shape: f32[30000000,444]
        ==========================
```
kinalmehta commented 7 months ago

You are not setting the max_replay_size parameter. When it is unspecified, Brax uses num_timesteps as the max_replay_size, which is 30_000_000 in your case, so a replay buffer of 30 million transitions is initialized up front and exhausts your VRAM. Set max_replay_size to a smaller value and things should work.
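A minimal sketch of that change against the training call from the issue (the 100_000 value is illustrative, not a recommendation; at 444 floats per transition it would put the buffer around 0.17 GiB instead of 49.62 GiB):

```python
train_fn = functools.partial(
    sac.train,
    num_timesteps=30_000_000,
    num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True,
    action_repeat=1,
    num_envs=1,
    batch_size=32,
    max_replay_size=100_000,  # cap the buffer; defaults to num_timesteps if omitted
    seed=0,
)
```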

AlexS28 commented 7 months ago

Awesome, thanks! My RL training is working now with that change.