NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

fix bug of attn backward in non-causal model with context parallelism enabled. #1031

Closed · wplf closed this 2 months ago

wplf commented 2 months ago

This bug causes the following failure: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: ~/megatron/bin/python.

That is because the rng_states required for the attention recompute (for dropout) in the backward pass are missing, and no hint about the cause is given.

It was very difficult to trace and cost me two weeks.

[before the start of training step] datetime: 2024-07-22 18:26:45 
[2024-07-22 18:27:00,941] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 0 (pid: 1761020) of binary: /home//miniconda3/envs/megatron/bin/python
Traceback (most recent call last):
  File "/home//miniconda3/envs/megatron/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.2.1+cu121', 'console_scripts', 'torchrun')())
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home//miniconda3/envs/megatron/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
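For context, here is a minimal sketch of why the saved RNG state matters when the dropout mask is regenerated during the backward pass. This is plain PyTorch, not TransformerEngine's actual context-parallel attention code, and the class name is hypothetical:

```python
import torch


class RecomputedDropout(torch.autograd.Function):
    """Toy stand-in for an op that regenerates its dropout mask in backward
    instead of keeping the mask in memory (as attention recompute does)."""

    @staticmethod
    def forward(ctx, x, dropout_p):
        # Save the CUDA RNG state *before* drawing the dropout mask, so the
        # identical mask can be regenerated during the backward recompute.
        ctx.rng_state = torch.cuda.get_rng_state()
        ctx.dropout_p = dropout_p
        mask = (torch.rand_like(x) >= dropout_p).to(x.dtype)
        return x * mask / (1.0 - dropout_p)

    @staticmethod
    def backward(ctx, grad_out):
        # Restore the saved RNG state and replay the same random draw to get
        # an identical mask. If the saved rng_state is missing (the situation
        # reported in this PR), the recompute cannot match the forward pass.
        current_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(ctx.rng_state)
        mask = (torch.rand_like(grad_out) >= ctx.dropout_p).to(grad_out.dtype)
        torch.cuda.set_rng_state(current_state)
        return grad_out * mask / (1.0 - ctx.dropout_p), None
```

With context parallelism, each rank runs its own attention chunk, so each rank's backward needs the RNG state that its own forward used; dropping those states leads to the crash shown in the log above.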

Description

With context parallelism enabled and a non-causal model, the attention backward pass is missing the rng_states needed to recompute attention (for dropout), which crashes training with exitcode -11 (see the log above). This PR adds the missing rng_states handling.

Fixes # (issue)

Type of change

Bug fix

Changes

Pass the rng_states required for the dropout recompute in the attention backward pass when context parallelism is enabled.

Checklist:

ptrendx commented 2 months ago

@xrennvidia @cyanguwa Could you both take a look?

xrennvidia commented 2 months ago

LGTM.

This is a good catch, I indeed missed it. Thanks for the fix.

ksivaman commented 2 months ago

/te-ci pytorch

cyanguwa commented 2 months ago

LGTM as well.

ptrendx commented 2 months ago

Thank you @wplf for your contribution :-)!