Hi, thanks for the question. If you are in a multi-GPU setting, you could try implementing data parallelism, e.g. with pmap (https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html), to split your batches across devices. Please also see the development branch, where there is an example of this: https://github.com/lindermanlab/S5/tree/development
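For illustration, here is a minimal, self-contained sketch of the pmap pattern with a toy step function (the parameters, loss, and shapes below are placeholders, not the actual S5 train_step; the development branch has the real multi-device example):

```python
import jax
import jax.numpy as jnp

def toy_train_step(params, batch):
    # Placeholder per-device step (not the S5 train_step): compute a loss,
    # take gradients, and average the gradients across devices.
    def loss_fn(p):
        return jnp.mean((batch @ p) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Collective mean over the pmapped axis so every device applies the
    # same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return params - 1e-3 * grads, loss

n_devices = jax.local_device_count()

# Replicate the parameters once per device.
params = jnp.ones((4,))
replicated_params = jnp.stack([params] * n_devices)

# Shard the batch so each device only holds bsz / n_devices examples.
batch = jnp.ones((8 * n_devices, 4))
sharded_batch = batch.reshape((n_devices, -1) + batch.shape[1:])

p_train_step = jax.pmap(toy_train_step, axis_name="batch")
new_params, losses = p_train_step(replicated_params, sharded_batch)
```

In a full training loop the train state would also need to be replicated across devices (flax provides flax.jax_utils.replicate for this) and per-device metrics averaged with jax.lax.pmean; the development branch shows how this is wired into S5.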
If that doesn't help, I am not sure what else you could do about the memory usage (this will be similar for any sequence model), other than making the model smaller, e.g. by reducing d_model.
Hopefully this helps.
Thanks for replying! This gives me some new ideas; I'll try implementing them.
Hi! I'm trying to increase the batch size to 1500 when training on CIFAR-10, but this makes the GPU run out of memory. I'm wondering if there's a solution, since I'm planning to use S5 on tasks that involve very large inputs. Here's the configuration (the shell script used to run the experiment):
python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \
  --blocks=3 --bsz=600 --clip_eigs=True --d_model=512 --dataset=lra-cifar-classification \
  --epochs=250 --jax_seed=16416 --lr_factor=4.5 --n_layers=6 --opt_config=BfastandCdecay \
  --p_dropout=0.1 --ssm_lr_base=0.001 --ssm_size_base=384 --warmup_end=1 --weight_decay=0.07
Here's the full error message:

2024-04-06 22:22:12.820520: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 35.65GiB (38280953856 bytes) by rematerialization; only reduced to 36.48GiB (39170521860 bytes)
2024-04-06 22:22:35.625179: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 37.88GiB (rounded to 40675475712) requested by op
Traceback (most recent call last):
  File "Path/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "Path/S5/s5/train.py", line 172, in train
    state, train_loss, step = train_epoch(state,
  File "Path/S5/s5/train_helpers.py", line 344, in train_epoch
    state, loss = train_step(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1349, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 40675475656 bytes.