Closed: AlexS28 closed this issue 7 months ago
You are not setting the `max_replay_size` parameter, and if it is not specified, Brax uses `num_timesteps` as the `max_replay_size`, which is 30_000_000 in your case. So a replay buffer of 30 million transitions is being initialized, which exhausts the VRAM. Try setting `max_replay_size` to a smaller value and things should work.
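For reference, a sketch of the suggested change applied to the call from the original report (the `100_000` cap is purely illustrative; tune it to your GPU memory and task):

```python
import functools
from brax.training.agents.sac import train as sac

# Cap the replay buffer explicitly; if omitted, Brax falls back to
# num_timesteps (30_000_000 here), which is what exhausted the VRAM.
train_fn = functools.partial(
    sac.train,
    num_timesteps=30_000_000,
    max_replay_size=100_000,  # illustrative value, not a recommendation
    num_evals=1,
    normalize_observations=True,
    action_repeat=1,
    num_envs=1,
    batch_size=32,
    seed=0,
)
```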
Awesome, thanks! My RL training is working now with that change.
Hi, I am trying to run my environment using MJX, and everything works perfectly fine when I use PPO: I can easily go to over 8000 environments with a batch size over 1000 without any memory issues. With SAC, however, it is a different story. No matter what settings I try (even a batch size of 32 with a single environment), I run into the memory allocation error copy/pasted below. For my particular environment (based on trying it in other simulators), it can only really train with SAC, not PPO, so I really do need to use SAC. Any help would be greatly appreciated.
My computer has a Nvidia 4090 GPU.
Running this training:

```python
train_fn = functools.partial(
    sac.train,
    num_timesteps=30_000_000,
    num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True,
    action_repeat=1,
    num_envs=1,
    batch_size=32,
    seed=0,
)
```
Getting this error:

```
2024-02-19 15:14:56.882326: W external/xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below -30.29GiB (-32529110990 bytes) by rematerialization; only reduced to 49.62GiB (53280000080 bytes), down from 49.62GiB (53280000080 bytes) originally
2024-02-19 15:15:06.899957: W external/tsl/tsl/framework/bfc_allocator.cc:487] Allocator (GPU_0_bfc) ran out of memory trying to allocate 49.62GiB (rounded to 53280000000) requested by op
2024-02-19 15:15:06.900601: W external/tsl/tsl/framework/bfc_allocator.cc:499] ____**_____
E0219 15:15:06.900728 6066 pjrt_stream_executor_client.cc:2766] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
  parameter allocation: 8B
  constant allocation: 8B
  maybe_live_out allocation: 49.62GiB
  preallocated temp allocation: 160B
  total allocation: 49.62GiB
  total fragmentation: 168B (0.00%)
Peak buffers:
  Buffer 1:
    Size: 49.62GiB
    Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
    XLA Label: fusion
    Shape: f32[30000000,444]
```
```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/sclaer/ASCENT_biped/sim/mujoco/rl_mjx.py", line 367, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/agents/sac/train.py", line 429, in train
    buffer_state = jax.pmap(replay_buffer.init)(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
  parameter allocation: 8B
  constant allocation: 8B
  maybe_live_out allocation: 49.62GiB
  preallocated temp allocation: 160B
  total allocation: 49.62GiB
  total fragmentation: 168B (0.00%)
Peak buffers:
  Buffer 1:
    Size: 49.62GiB
    Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
    XLA Label: fusion
    Shape: f32[30000000,444]
```
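The peak buffer in the dump matches a quick back-of-the-envelope check: the replay buffer is allocated as a single `f32[30000000, 444]` array, i.e. 30 million rows (one per timestep, since `max_replay_size` defaulted to `num_timesteps`) of 444 float32 values each:

```python
# Back-of-the-envelope size of the replay buffer from the OOM dump.
rows = 30_000_000        # max_replay_size, defaulted to num_timesteps
cols = 444               # flattened transition width from the error message
bytes_total = rows * cols * 4  # float32 = 4 bytes

print(bytes_total)                    # 53280000000 bytes, as in the log
print(round(bytes_total / 2**30, 2))  # 49.62 GiB, matching the allocator dump
```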