jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0
529 stars, 33 forks

OOM when seq-length=700k #11

Open jkl375 opened 2 months ago

jkl375 commented 2 months ago

Hi, author. When I set seq-length=700k, an OOM error occurred. My torch version is 2.4.0.dev20240324. Do I need to set gradient-accumulate-every to 1?

Max train steps: 90
  0%|          | 0/90 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/utils/checkpoint.py:144: UserWarning: Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. Device state will only be saved for devices of a single device type, and the remaining devices will be ignored. Consequently, if any checkpointed functions involve randomness, this may result in incorrect gradients. (Note that if CUDA devices are among the devices detected, it will be prioritized; otherwise, the first device encountered will be selected.)
  warnings.warn(
[2024-04-08 12:44:38,013] [WARNING] [stage3.py:2069:step] 9 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  1%|█▌ | 1/90 [16:59<25:12:38, 1019.75s/it, loss=7.13, ppl=1.25e+3]
[2024-04-08 13:02:16,577] [WARNING] [stage3.py:2069:step] 27 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  2%|███ | 2/90 [34:38<25:29:07, 1042.58s/it, loss=5.97, ppl=390]
[2024-04-08 13:19:59,481] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  3%|████▋ | 3/90 [52:21<25:25:11, 1051.86s/it, loss=5.88, ppl=359]
[2024-04-08 13:37:46,252] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  4%|██████ | 4/90 [1:10:07<25:16:06, 1057.75s/it, loss=5.88, ppl=359]
[2024-04-08 13:55:34,933] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  6%|███████▌ | 5/90 [1:27:56<25:04:03, 1061.69s/it, loss=5.71, ppl=301]
[2024-04-08 14:13:19,565] [WARNING] [stage3.py:2069:step] 29 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  7%|█████████▏ | 6/90 [1:45:41<24:47:45, 1062.69s/it, loss=5.48, ppl=240]
[2024-04-08 14:30:59,959] [WARNING] [stage3.py:2069:step] 29 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  8%|██████████▋ | 7/90 [2:03:21<24:29:01, 1061.94s/it, loss=5.54, ppl=256]
[2024-04-08 14:48:45,778] [WARNING] [stage3.py:2069:step] 30 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
  9%|████████████▏ | 8/90 [2:21:07<24:13:00, 1063.17s/it, loss=5.24, ppl=189]
[2024-04-08 15:06:28,424] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 10%|█████████████▋ | 9/90 [2:38:50<23:55:03, 1063.01s/it, loss=5.18, ppl=177]
[2024-04-08 15:24:17,016] [WARNING] [stage3.py:2069:step] 30 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 11%|███████████████ | 10/90 [2:56:38<23:39:38, 1064.73s/it, loss=5.08, ppl=161]
[2024-04-08 15:41:58,421] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 12%|████████████████▌ | 11/90 [3:14:20<23:20:33, 1063.71s/it, loss=5.01, ppl=150]
[2024-04-08 15:59:46,293] [WARNING] [stage3.py:2069:step] 31 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 13%|██████████████████▏ | 12/90 [3:32:08<23:04:28, 1064.98s/it, loss=4.93, ppl=138]
[2024-04-08 16:17:37,821] [WARNING] [stage3.py:2069:step] 29 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 14%|███████████████████▋ | 13/90 [3:49:59<22:49:16, 1066.96s/it, loss=4.96, ppl=142]
[2024-04-08 16:35:20,389] [WARNING] [stage3.py:2069:step] 30 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 16%|█████████████████████▏ | 14/90 [4:07:42<22:29:48, 1065.64s/it, loss=5.03, ppl=152]
[2024-04-08 16:53:01,507] [WARNING] [stage3.py:2069:step] 30 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 17%|██████████████████████▋ | 15/90 [4:25:23<22:10:20, 1064.27s/it, loss=4.85, ppl=127]
[2024-04-08 17:10:52,587] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 18%|████████████████████████▏ | 16/90 [4:43:14<21:55:07, 1066.32s/it, loss=4.76, ppl=117]
[2024-04-08 17:28:38,735] [WARNING] [stage3.py:2069:step] 29 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 19%|█████████████████████████▋ | 17/90 [5:01:00<21:37:17, 1066.27s/it, loss=4.91, ppl=135]
[2024-04-08 17:46:16,972] [WARNING] [stage3.py:2069:step] 28 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 20%|███████████████████████████▏ | 18/90 [5:18:38<21:16:37, 1063.86s/it, loss=4.93, ppl=138]
[2024-04-08 18:04:03,970] [WARNING] [stage3.py:2069:step] 30 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
 21%|████████████████████████████▉ | 19/90 [5:36:25<21:00:00, 1064.80s/it, loss=4.9, ppl=134]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/data/jkl/proj/EasyContext/train.py", line 219, in <module>
[rank1]:     main(args.parse_args())
[rank1]:   File "/data/jkl/proj/EasyContext/train.py", line 138, in main
[rank1]:     accelerator.backward(loss)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/accelerator.py", line 1995, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2213, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 304, in backward
[rank1]:     outputs = ctx.run_function(*detached_inputs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/proj/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py", line 84, in new_decoder_forward
[rank1]:     hidden_states = self.mlp(hidden_states)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 240, in forward
[rank1]:     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1577, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.79 GiB. GPU  has a total capacity of 79.15 GiB of which 801.25 MiB is free. Including non-PyTorch memory, this process has 78.33 GiB memory in use. Of the allocated memory 45.74 GiB is allocated by PyTorch, and 31.66 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0408 18:07:00.108000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735592 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735594 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735595 closing signal SIGTERM
W0408 18:07:00.109000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735596 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735597 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735598 closing signal SIGTERM
W0408 18:07:00.110000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 735599 closing signal SIGTERM
E0408 18:07:21.018000 140664182719680 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 1 (pid: 735593) of binary: /data/jkl/miniconda3/envs/easycontext/bin/python
Traceback (most recent call last):
  File "/data/jkl/miniconda3/envs/easycontext/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1042, in launch_command
    deepspeed_launcher(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/accelerate/commands/launch.py", line 754, in deepspeed_launcher
    distrib_run.run(args)
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/data/jkl/miniconda3/envs/easycontext/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-04-08_18:07:00
  host      : ubuntu
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 735593)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
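The size of the failed request can be checked with simple arithmetic. The sketch below assumes the model is Llama-2-7B (intermediate_size = 11008), that activations are bf16, and that "700k" means exactly 700,000 tokens split evenly across the 8 ranks visible in the torchrun shutdown log. None of these are stated in the issue. It also computes how much of the memory PyTorch held on rank 1 was reserved but unallocated, which the OOM message itself flags as a possible sign of fragmentation.

```python
# Back-of-the-envelope check of the OOM message on rank 1.
# ASSUMPTIONS (not from the issue): Llama-2-7B MLP width, bf16 activations,
# and a 700,000-token sequence split evenly over 8 ranks.
GIB = 1024 ** 3

seq_len = 700_000          # assumed meaning of "seq-length=700k"
world_size = 8             # 8 worker processes appear in the shutdown log
intermediate_size = 11008  # assumed Llama-2-7B MLP width
bytes_per_elem = 2         # bf16

# Each rank holds an equal slice of the sequence.
tokens_per_rank = seq_len // world_size

# gate_proj and up_proj each produce a [tokens_per_rank, intermediate_size] tensor.
mlp_activation_gib = tokens_per_rank * intermediate_size * bytes_per_elem / GIB

# Figures copied from the OOM message.
allocated_gib = 45.74
reserved_unallocated_gib = 31.66
fragmented_share = reserved_unallocated_gib / (allocated_gib + reserved_unallocated_gib)

print(f"tokens per rank: {tokens_per_rank}")
print(f"one MLP intermediate tensor: {mlp_activation_gib:.2f} GiB")
print(f"reserved-but-unallocated share: {fragmented_share:.0%}")
```

Under these assumptions the 1.79 GiB request matches a single 87,500 x 11008 bf16 output of gate_proj or up_proj, while about 41% of the memory PyTorch held sat in reserved but unallocated blocks. That pattern is consistent with the allocator-related fix reported further down the thread, though it does not prove fragmentation was the only cause.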
jzhang38 commented 2 months ago

Hmm, interesting. Honestly, I did not run 700K for this long. I only ran the first 5 steps and called it a day due to my limited compute resources. Yeah, I think setting the gradient accumulation to 1 would help.

jkl375 commented 2 months ago

After setting gradient accumulation to 1, OOM still occurs at step 7.

jkl375 commented 2 months ago

I wonder if this is related to the limits mentioned by zhuzilin, the author of ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention?tab=readme-ov-file#limits

jkl375 commented 2 months ago

Just prefix the accelerate command with PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024'. With that, training has reached step 37 without problems so far.
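If prefixing the shell command is inconvenient, the same setting can be applied from Python. The sketch below is a hypothetical helper (add_alloc_option is not part of EasyContext or PyTorch) that adds or overrides one key:value option in PYTORCH_CUDA_ALLOC_CONF without discarding options already set, such as the expandable_segments:True suggested in the OOM message. It assumes the variable is set before torch is imported, since the CUDA caching allocator reads it when it initialises; setting it later in the process may have no effect.

```python
import os


def add_alloc_option(key: str, value: str) -> str:
    """Hypothetical helper: add or override one option in PYTORCH_CUDA_ALLOC_CONF.

    The variable holds comma-separated key:value pairs. Call this before
    importing torch so the CUDA caching allocator sees the new value.
    """
    current = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    options = {}
    # Parse existing "key:value" pairs, skipping empty entries.
    for item in filter(None, (part.strip() for part in current.split(","))):
        k, _, v = item.partition(":")
        options[k] = v
    # Add the option, or replace its value in place if already present.
    options[key] = value
    conf = ",".join(f"{k}:{v}" for k, v in options.items())
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf
    return conf


# Equivalent to the shell prefix used above:
#   PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' accelerate launch ...
add_alloc_option("max_split_size_mb", "1024")
# import torch  # import only after the variable has been set
```

Merging rather than overwriting keeps any options set by the launcher or the shell. Whether expandable_segments:True would work as well as, or better than, max_split_size_mb:1024 for this run was not tested in the thread.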