jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0

RuntimeError: CUDA error: an illegal memory access was encountered #46

Open uditsharma7 opened 3 months ago

uditsharma7 commented 3 months ago

I am hitting this error while using zigzag_ring_attn with a 128k context length. Has anyone run into the same problem?

[rank0]:   File "/app/c2j-long-context-model-training/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py", line 69, in new_decoder_forward
[rank0]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 469, in forward
[rank0]:     attn_output = self._flash_attention_forward(
[rank0]:   File "/app/c2j-long-context-model-training/EasyContext/easy_context/zigzag_ring_attn/monkey_patch.py", line 29, in new_flash_attn_forward
[rank0]:     attn_output = zigzag_ring_flash_attn_func(
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn.py", line 312, in zigzag_ring_flash_attn_func
[rank0]:     return ZigZagRingFlashAttnFunc.apply(
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn.py", line 202, in forward
[rank0]:     out, softmax_lse = zigzag_ring_flash_attn_forward(
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/ring_flash_attn/zigzag_ring_flash_attn.py", line 59, in zigzag_ring_flash_attn_forward
[rank0]:     out, lse = update_out_and_lse(
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/ring_flash_attn/utils.py", line 44, in update_out_and_lse
[rank0]:     slice_out, slice_lse = _update_out_and_lse(
[rank0]: RuntimeError: The following operation failed in the TorchScript interpreter.
[rank0]: Traceback of TorchScript (most recent call last):
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/ring_flash_attn/utils.py", line 24, in _update_out_and_lse
[rank0]:     # For additional context and discussion, please refer to:
[rank0]:     # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
[rank0]:     out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank0]:                 ~~~~~~~~~ <--- HERE
[rank0]:     lse = lse - F.logsigmoid(lse - block_lse)
[rank0]:   File "/opt/conda/envs/ai/lib/python3.10/site-packages/torch/nn/functional.py", line 2013, in sigmoid
[rank0]:     See :class:`~torch.nn.Sigmoid` for more details.
[rank0]:     """
[rank0]:     return input.sigmoid()
[rank0]:            ~~~~~~~~~~~~~ <--- HERE
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
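
For anyone reading the traceback: the two lines flagged in `_update_out_and_lse` fold one attention block into the running output and its log-sum-exp. Below is a minimal scalar sketch of that update in plain Python, checked against a direct softmax-weighted merge. It is not the ring_flash_attn code; `merge_block` and `merge_reference` are hypothetical names used only for illustration, and the real function operates on CUDA tensors.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logsigmoid(x: float) -> float:
    # log(sigmoid(x)), written to avoid overflow for large |x|.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def merge_block(out, lse, block_out, block_lse):
    # Same two expressions as in the traceback, on scalars (hypothetical helper).
    new_out = out - sigmoid(block_lse - lse) * (out - block_out)
    new_lse = lse - logsigmoid(lse - block_lse)
    return new_out, new_lse


def merge_reference(out, lse, block_out, block_lse):
    # Direct form: weight each partial output by exp(lse), then renormalize.
    w_old, w_new = math.exp(lse), math.exp(block_lse)
    total = w_old + w_new
    return (w_old * out + w_new * block_out) / total, math.log(total)


merged = merge_block(0.25, 1.5, -0.75, 0.5)
reference = merge_reference(0.25, 1.5, -0.75, 0.5)
```

The update itself is ordinary elementwise math. CUDA errors are reported asynchronously, so the `sigmoid` call may only be where an earlier kernel's failure surfaced; rerunning with `CUDA_LAUNCH_BLOCKING=1` is the usual way to localize the faulting call.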

uditsharma7 commented 2 months ago

With a smaller context length, the run crashes with a segmentation fault instead (SIGSEGV, exit code -11, on local rank 2):

W0904 15:27:23.932000 23280343537216 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 261407 closing signal SIGTERM
W0904 15:27:23.932000 23280343537216 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 261408 closing signal SIGTERM
W0904 15:27:23.932000 23280343537216 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 261410 closing signal SIGTERM
E0904 15:27:34.716000 23280343537216 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: -11) local_rank: 2 (pid: 261409) of binary: /dccstor/udit/env/easy_context/bin/python
Traceback (most recent call last):
  File "/udit/env/easy_context/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/udit/env/easy_context/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/udit/env/easy_context/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1091, in launch_command
    deepspeed_launcher(args)
  File "/udit/env/easy_context/lib/python3.10/site-packages/accelerate/commands/launch.py", line 787, in deepspeed_launcher
    distrib_run.run(args)
  File "/dccstor/udit/env/easy_context/lib/python3.10/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/udit/env/easy_context/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/udit/env/easy_context/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
EasyContext/train.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-04_15:27:23
  host      : cccxc614.pok.ibm.com
  rank      : 2 (local_rank: 2)
  exitcode  : -11 (pid: 261409)
  error_file: <N/A>
  traceback : Signal 11 (SIGSEGV) received by PID 261409
========================================================
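
The launcher output above only reports that rank 2 received SIGSEGV, with no Python frame. One way to get more detail is the standard-library `faulthandler` module, which dumps the Python traceback when the process receives a fatal signal. A minimal sketch, assuming it is placed near the top of the training script (the helper name `enable_crash_tracebacks` is hypothetical):

```python
import faulthandler
import sys


def enable_crash_tracebacks() -> bool:
    # Ask the interpreter to print Python tracebacks for all threads
    # to stderr on fatal signals such as SIGSEGV.
    faulthandler.enable(file=sys.stderr, all_threads=True)
    return faulthandler.is_enabled()


crash_tracebacks_on = enable_crash_tracebacks()
```

Setting the environment variable `PYTHONFAULTHANDLER=1` before launching has the same effect without editing the script, and it is inherited by the worker processes. This shows only where the Python side was when the crash happened, not the native frame that faulted.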