pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

[BUG] num_stages computed incorrectly, leading to assertion failures #1143

Open jq-wei opened 3 weeks ago

jq-wei commented 3 weeks ago

Hi,

First of all, thank you for the great work.

I am trying the Llama example script (pippy_llama.py) with llama2-7b-hf and the following key packages:

torch                    2.5.0
torchpippy               0.2.0
torchtext                0.6.0
torchview                0.2.6
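
For context, the relevant part of my test script (essentially the pippy_llama.py example; the model path, prompts, and split_spec layout below are my own placeholders, and the exact pipeline() arguments may differ slightly between versions) looks roughly like this:

import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import SplitPoint, pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
dist.init_process_group(rank=rank, world_size=world_size)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # local checkpoint in my setup
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
prompts = ["Hello, my name is"] * world_size  # one microbatch per rank
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Ask for one pipeline stage per rank by splitting at decoder-layer boundaries
layers_per_rank = model.config.num_hidden_layers // world_size
split_spec = {
    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}
pipe = pipeline(model, mb_args=(inputs["input_ids"],), split_spec=split_spec)

stage = pipe.build_stage(rank, device=device)  # line 48 in my script, where the error is raised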

When I ran torchrun --nproc-per-node 4 pippy_llama.py, I got the following error on device 0:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]:     stage = pipe.build_stage(rank,  device=device)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]:     return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 799, in __init__
[rank0]:     _PipelineStageBase.__init__(
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 138, in __init__
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Pipeline group size 4 cannot be larger than number of stages 1

I traced this back to _number_and_count_forward_stages in _IR.py: num_stages is indeed 1, because only one node in the traced graph has node.op == "call_module", while all the other nodes have node.op == "call_function".
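
In case it is useful, this is roughly how I checked it (the split_gm attribute name comes from the traceback above; this is plain graph inspection, nothing else is assumed):

# Inspect the split graph produced by pipeline(): only one submodule shows up.
for node in pipe.split_gm.graph.nodes:
    print(node.op, node.name)

num_call_modules = sum(1 for n in pipe.split_gm.graph.nodes if n.op == "call_module")
print("call_module nodes (stages):", num_call_modules)  # prints 1 instead of the expected 4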

Just to dig a bit deeper, I hard-coded the return value of _number_and_count_forward_stages to 4.
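
Roughly, the change was the following (the function lives in torch/distributed/pipelining/_IR.py; the signature and body here are abbreviated guesses, and this is a debugging hack only, not a fix):

def _number_and_count_forward_stages(gm):  # signature is a guess; only the return was changed
    ...  # original numbering/counting logic left as-is
    return 4  # hard-coded instead of the computed stage count

With that hack in place, I got the following errors: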

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]:     stage = pipe.build_stage(rank,  device=device)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]:     return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 816, in __init__
[rank0]:     raise AssertionError(
[rank0]: AssertionError: Number of submodules in pipe graph 1 does not match number of stages 4
[rank2]: Traceback (most recent call last):
[rank2]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank2]:     stage = pipe.build_stage(rank,  device=device)
[rank2]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank2]:     stage_module = self.get_stage_module(stage_index)
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank2]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank2]:     raise AttributeError(
[rank2]: AttributeError: 'GraphModule' object has no attribute 'submod_2'. Did you mean: 'submod_0'?
[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank1]:     stage = pipe.build_stage(rank,  device=device)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank1]:     stage_module = self.get_stage_module(stage_index)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank1]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank1]:     raise AttributeError(
[rank1]: AttributeError: 'GraphModule' object has no attribute 'submod_1'. Did you mean: 'submod_0'?
[rank3]: Traceback (most recent call last):
[rank3]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank3]:     stage = pipe.build_stage(rank,  device=device)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank3]:     stage_module = self.get_stage_module(stage_index)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank3]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank3]:     raise AttributeError(
[rank3]: AttributeError: 'GraphModule' object has no attribute 'submod_3'. Did you mean: 'submod_0'?

It seems the version-matching problem is still there. Incidentally, the same errors occur if I uninstall torchpippy.

Could you give me some hints?

Thank you very much!

Noblezhong commented 2 weeks ago

I also have this problem when running this script on a server with multiple GPUs, but the GPT example runs successfully. Do you have any solution for this? It seems this code cannot run successfully in a single-server environment.