lindermanlab / S5

MIT License

Out of memory when batch size is large #14

Closed: William-HYWu closed this issue 7 months ago

William-HYWu commented 7 months ago

Hi! I'm trying to increase the batch size for training on CIFAR-10 to 1500. However, at that batch size the GPU runs out of memory. Is there a solution for this? I'm planning to use S5 on tasks that involve very large inputs.

Here's the configuration (the shell script used to run the experiment):

```shell
python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \
    --blocks=3 --bsz=600 --clip_eigs=True --d_model=512 --dataset=lra-cifar-classification \
    --epochs=250 --jax_seed=16416 --lr_factor=4.5 --n_layers=6 --opt_config=BfastandCdecay \
    --p_dropout=0.1 --ssm_lr_base=0.001 --ssm_size_base=384 --warmup_end=1 --weight_decay=0.07
```

Here's the full error message:

```
2024-04-06 22:22:12.820520: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 35.65GiB (38280953856 bytes) by rematerialization; only reduced to 36.48GiB (39170521860 bytes)
2024-04-06 22:22:35.625179: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 37.88GiB (rounded to 40675475712)requested by op
Traceback (most recent call last):
  File "Path/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "Path/S5/s5/train.py", line 172, in train
    state, train_loss, step = train_epoch(state,
  File "Path/S5/s5/train_helpers.py", line 344, in train_epoch
    state, loss = train_step(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1349, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 40675475656 bytes.
```

jimmysmith1919 commented 7 months ago

Hi, thanks for the question. You could try implementing data parallelism, e.g. with pmap (https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html), to split your batches across devices if you are in a multi-GPU setting. See also the development branch, which has an example of this: https://github.com/lindermanlab/S5/tree/development
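A minimal sketch of the pmap data-parallel pattern described in the linked JAX tutorial. It uses a hypothetical toy linear model rather than S5, and it is not the development-branch implementation; `loss_fn`, `train_step`, `shard` and `replicate` are illustrative names. Each device only sees `global_bsz / n_devices` examples, so per-device activation memory shrinks by roughly that factor, while `jax.lax.pmean` keeps the parameter replicas in sync. It also runs on a single CPU device (`n_devices == 1`).

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (one linear layer) standing in for the S5 network.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y):
    # Runs on every device with that device's slice of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients (and the loss) across devices so every replica
    # applies the same update and the parameters stay identical.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss

p_train_step = jax.pmap(train_step, axis_name="batch")

def shard(batch, n_devices):
    # (global_bsz, ...) -> (n_devices, global_bsz // n_devices, ...)
    assert batch.shape[0] % n_devices == 0, "global batch must divide evenly"
    return batch.reshape((n_devices, -1) + batch.shape[1:])

def replicate(tree, n_devices):
    # Give every device its own copy of the parameters (leading device axis).
    return jax.tree_util.tree_map(
        lambda a: jnp.broadcast_to(a, (n_devices,) + a.shape), tree)

n_devices = jax.local_device_count()
global_bsz = 8 * n_devices  # e.g. bsz=1500 on 4 GPUs would be 375 per device
x = jax.random.normal(jax.random.PRNGKey(0), (global_bsz, 4))
y = x @ jnp.ones((4, 1))
params = replicate({"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}, n_devices)

losses = []
for _ in range(20):
    params, loss = p_train_step(params, shard(x, n_devices), shard(y, n_devices))
    losses.append(float(loss[0]))  # same value on every device after pmean
```

The global batch size must be divisible by the number of devices; the same reshape-then-pmap idea would apply to the real training step, with the S5 model's loss in place of the toy `loss_fn`.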

If not, I am not sure what else you can do about the memory usage (this will be similar for any sequence model), other than perhaps making the model smaller, e.g. by reducing d_model.
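A back-of-the-envelope sketch of why both batch size and d_model matter: the size of a single float32 activation tensor of shape (bsz, seq_len, d_model) grows linearly in each. It assumes seq_len = 1024 (a flattened 32×32 grayscale image, as in the LRA image task) and 4 bytes per element; `activation_gib` is a hypothetical helper. The real model keeps many such tensors across its layers, plus SSM states and backward-pass buffers, so the total (like the ~38 GiB request in the log) is a large multiple of this figure.

```python
def activation_gib(batch: int, seq_len: int, d_model: int, bytes_per_elem: int = 4) -> float:
    """Size in GiB of ONE activation tensor of shape (batch, seq_len, d_model).

    Hypothetical helper: real memory use is many such tensors per layer,
    so treat this only as a scaling guide, not a total.
    """
    return batch * seq_len * d_model * bytes_per_elem / 2**30

SEQ_LEN = 1024  # assumption: flattened 32x32 grayscale image

for bsz, d_model in [(600, 512), (1500, 512), (1500, 256)]:
    gib = activation_gib(bsz, SEQ_LEN, d_model)
    print(f"bsz={bsz:5d}  d_model={d_model:4d}  ->  {gib:.3f} GiB per tensor")
```

Going from bsz=600 to 1500 multiplies this figure by 2.5, while halving d_model halves it; sharding the batch over n devices divides the per-device figure by n.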

Hopefully this helps.

William-HYWu commented 7 months ago

Thanks for the reply! This gives me some new ideas. I'll try implementing them.