pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

[BUG] cannot capture your model as a full graph #1132

Open sunkun1997 opened 1 month ago

sunkun1997 commented 1 month ago

torch version: 2.5.0.dev20240616+cu121; Python version: 3.8

I ran the Llama example with torchrun --nproc-per-node 2 pippy_llama.py and got the following error:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.26s/it]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.27s/it]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
layers_per_rank = 16
layers_per_rank = 16
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1006, in _trace_with_export
[rank0]:     ep = torch.export.export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/__init__.py", line 174, in export
[rank0]:     return _export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 952, in wrapper
[rank0]:     raise e
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 935, in wrapper
[rank0]:     ep = fn(*args, **kwargs)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/exported_program.py", line 91, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 1547, in _export
[rank0]:     exported_program = ExportedProgram(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/exported_program.py", line 248, in __init__
[rank0]:     self.verifier().check(self)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 154, in check
[rank0]:     self._check_graph_module(ep.graph_module)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 220, in _check_graph_module
[rank0]:     _check_val(node)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 62, in _check_val
[rank0]:     raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
[rank0]: torch._export.verifier.SpecViolationError: Node.meta _enter_autocast is missing val field.

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "pippy_llama.py", line 36, in <module>
[rank0]:     pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],))
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1236, in pipeline
[rank0]:     return Pipe.from_tracing(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1044, in from_tracing
[rank0]:     exported_program = Pipe._trace_with_export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1012, in _trace_with_export
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: It seems that we cannot capture your model as a full graph. Typical reasons include graph breaks, data/shape-dependent control flow, or missing meta kernels for custom operators. You can use our manual pipeline interfaces, or try to fix the graph breaks, see https://pytorch.org/docs/stable/export.html
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
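
One way to narrow down a failure like this is to run torch.export directly on the model, outside of pipeline(), so the failing op (here the autocast enter/exit nodes) shows up with the full export diagnostics. A minimal sketch, assuming llama and mb_inputs are the objects already defined in pippy_llama.py:

import torch

# Reproduce the export step that pipeline() performs internally (the
# _trace_with_export frame in the traceback above), without the PiPPy wrapper.
try:
    ep = torch.export.export(llama, (mb_inputs["input_ids"],))
    ep.graph_module.print_readable()  # export succeeded: dump the captured graph
except Exception as err:
    print(f"torch.export failed: {err}")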
apresunreve commented 1 month ago

Same problem

ishan-gaur commented 1 month ago

This can (at least temporarily) be fixed by removing the autocast block in transformers/models/llama/modeling_llama.py and replacing everything from the # Force … comment in the forward pass with:

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)

(This essentially takes the Llama model back to commit 7628b3a0f40212c0f264233fc6da0d9c9cf88853 of the transformers package.)
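
For context, a sketch of what the patched LlamaRotaryEmbedding.forward ends up looking like after that change. The two setup lines are copied from the transformers versions around that commit and may differ in your installed copy, so treat this as an illustration rather than an exact patch:

import torch

# Method of LlamaRotaryEmbedding in transformers/models/llama/modeling_llama.py
@torch.no_grad()
def forward(self, x, position_ids):
    # Setup as in the surrounding transformers code (verify against your version).
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Replacement for the torch.autocast block: do the matmul in float32 directly,
    # so torch.export never emits _enter_autocast / _exit_autocast nodes.
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)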

However, after doing this, there still seems to be a problem where the compiled (traced? split?) model graph does not match the original:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3385, in <module>
[rank0]:     main()
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3378, in main
[rank0]:     globals = debugger.run(setup['file'], None, None, is_module)
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2446, in run
[rank0]:     return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2453, in _exec
[rank0]:     pydev_imports.execfile(file, globals, locals)  # execute the script
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydev_bundle/_pydev_execfile.py", line 25, in execfile
[rank0]:     exec(compile(contents + "\n", file, 'exec'), glob, loc)
[rank0]:   File ".../pippy_llama.py", line 45, in <module>
[rank0]:     pipe = pipeline(llama, example_args=(mb_inputs["input_ids"],), num_chunks=int(len(full_batch_prompts) / len(mb_inputs)))
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 1187, in pipeline
[rank0]:     return Pipe.from_tracing(
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 1030, in from_tracing
[rank0]:     pipe = Pipe._from_traced(
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 734, in _from_traced
[rank0]:     new_submod = _outline_submodules(submodule.graph)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_unflatten.py", line 23, in _outline_submodules
[rank0]:     ).run_outer()
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 862, in run_outer
[rank0]:     self.run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 919, in run_from
[rank0]:     self.finalize_outputs()
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 841, in finalize_outputs
[rank0]:     _verify_graph_equivalence(self.cached_graph_module, self.module)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 567, in _verify_graph_equivalence
[rank0]:     assert graph_dump(x.graph) == graph_dump(y.graph)
sunkun1997 commented 1 month ago

> Was able to solve the above in my case by turning off the kv-cache in the model config. Perhaps this needs to be manually managed by the user outside of the traced module.

Can you tell me exactly which line you're replacing? And is the way to turn off the kv-cache to change "use_cache": true to false in config.json?
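
For reference, the kv-cache can also be turned off from Python without editing config.json, via the standard transformers use_cache flag; as the next comment notes, this turned out not to be the fix here. The checkpoint id below is a placeholder:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint; use whichever Llama weights you are loading.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", use_cache=False)
# equivalently, after loading:
model.config.use_cache = False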

ishan-gaur commented 1 month ago

Sorry, the kv-cache suggestion was wrong. I was trying out GPT-2 earlier to make sure I could at least run something.

I also had the wrong commit number earlier. I was talking about reverting this change in the transformers library: https://github.com/huggingface/transformers/commit/d45f47ab7f7c31991bb98a0302ded59ab6adac31

ishan-gaur commented 1 month ago

I was able to resolve this by reverting transformers to the last December 2023 commit that passes all tests (3b7675b2b844b02d4821b827871a21ad16dd446c) and using the PiPPy v0.2.0 tag. If you need batched chat-template decoding, you also need to pull in the updated tokenization_utils_base.py and that folder's __init__.py accordingly.