pytorch / pytorch

[Compile] Running Llama2 with torch.compile and FSDP results in Type mismatch assert in LlamaRotaryEmbedding #108211

Open lessw2020 opened 1 year ago

lessw2020 commented 1 year ago

🐛 Describe the bug

Running torch.compile with Llama 7B and FSDP mixed precision results in an assert during the first forward pass of training (to repro, check out https://github.com/lessw2020/llama-recipes/tree/rotary_embeddings and run "bash run.sh"):

    assert a == b, f"{a} != {b}"
AssertionError: torch.float32 != torch.bfloat16
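
For context, a minimal sketch of the kind of setup that triggers this (hypothetical and condensed; the exact script is in the linked recipe repo):

    # Hypothetical, condensed repro sketch -- not the exact recipe script.
    # FSDP with a bf16 mixed-precision policy over fp32 weights, then
    # torch.compile on top.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

    dist.init_process_group("nccl")
    model = load_llama_7b()  # hypothetical loader; parameters stay in fp32

    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    model = FSDP(model, mixed_precision=bf16_policy)
    model = torch.compile(model)

    # `batch` is a tokenized training batch (hypothetical);
    # the assert fires during this first forward pass.
    loss = model(**batch).loss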

The assert is raised from this section (full trace below):

   cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 123, in forward
    self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),

Effectively there is a type mismatch, but after adding some debug output to the rotary cache and the incoming tensors, everything appears to be fp32 in eager mode.
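
For illustration, the kind of debug prints added (hypothetical sketch inside LlamaRotaryEmbedding.forward; only the x.dtype print appears in the trace below):

    # Hypothetical debug prints inside LlamaRotaryEmbedding.forward.
    # In eager mode these all report torch.float32, yet Dynamo's
    # fake-tensor metadata check asserts fp32 != bf16.
    def forward(self, x, seq_len=None):
        print(f"{x.dtype=}")                # incoming tensor
        print(f"{self.cos_cached.dtype=}")  # rotary cos cache
        print(f"{self.sin_cached.dtype=}")  # rotary sin cache
        ...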

Here's the full stack trace:


Training Epoch0:   0%|                                                                                                 | 0/48 [01:11<?, ?it/s]
Traceback (most recent call last):
  File "/data/home/less/llama_rotary/llama_finetuning.py", line 262, in <module>
    fire.Fire(main)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/data/home/less/llama_rotary/llama_finetuning.py", line 245, in main
    results = train(
  File "/data/home/less/llama_rotary/utils/train_utils.py", line 92, in train
    loss = model(**batch).loss
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 807, in forward
    outputs = self.model(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 694, in forward
    layer_outputs = decoder_layer(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 409, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/optimum/bettertransformer/models/decoder_models.py", line 387, in forward
    return llama_forward(self, *args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 616, in llama_forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 122, in forward
    print(f"{x.dtype=}")
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 488, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 625, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 139, in _fn
    return fn(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 555, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 477, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 444, in transform
    tracer.run()
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1191, in LOAD_ATTR
    result = BuiltinVariable(getattr).call_function(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 1116, in call_getattr
    obj.var_getattr(tx, name).clone(source=source).add_options(options)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/user_defined.py", line 482, in var_getattr
    return VariableBuilder(tx, source)(subobj).add_options(options)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 223, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 368, in _wrap
    return type_dispatch(self, value)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 879, in wrap_tensor
    return self.tx.output.register_attr_or_module(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 728, in register_attr_or_module
    return wrap_name(name)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 634, in wrap_name
    return wrap_fx_proxy(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1187, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1302, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1583, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
    return fn()
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1584, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1720, in from_tensor
    return self.fake_tensor_converter(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 371, in __call__
    return self.from_real_tensor(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 324, in from_real_tensor
    out = self.meta_converter(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 595, in __call__
    r = self.meta_tensor(
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 493, in meta_tensor
    assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 79, in assert_metadata_eq
    return go(m1, m2)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 74, in go
    go(m1._base, m2._base)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 51, in go
    assert_eq(m1.dtype, m2.dtype)
  File "/data/home/less/miniconda3/lib/python3.9/site-packages/torch/_subclasses/meta_utils.py", line 46, in assert_eq
    assert a == b, f"{a} != {b}"
AssertionError: torch.float32 != torch.bfloat16

from user code:
   File "/data/home/less/miniconda3/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 124, in <resume in forward>
    self.cos_cached[:, :, :seq_len, ...].to(dtype=torch.bfloat16), # x.dtype),

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230825+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-1038-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 1250.736
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e6216047b8
[pip3] st-moe-pytorch==0.0.22
[pip3] torch==2.1.0.dev20230825+cu121
[pip3] torchaudio==2.1.0.dev20230825+cu121
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.16.0.dev20230825+cu121
[pip3] vit-pytorch==1.4.1
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+e6216047b8 pypi_0 pypi
[conda] st-moe-pytorch 0.0.22 pypi_0 pypi
[conda] torch 2.1.0.dev20230825+cu121 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230825+cu121 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230825+cu121 pypi_0 pypi
[conda] vit-pytorch 1.4.1 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @msaroufim @bdhirsh @anijain2305 @kiukchung @d4l3k @lucasllc

msaroufim commented 1 year ago

FWIW, I couldn't repro this by just compiling the LlamaRotaryEmbedding from HuggingFace on a single GPU without FSDP.
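
A sketch of that non-repro attempt (hypothetical; module signatures vary slightly across transformers versions):

    # Hypothetical single-GPU check: compile just the rotary embedding,
    # no FSDP and no mixed precision -- this runs without the dtype assert.
    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    rotary = torch.compile(LlamaRotaryEmbedding(dim=128).cuda())
    x = torch.randn(1, 32, 2048, 128, device="cuda")  # (bsz, heads, seq, head_dim)
    cos, sin = rotary(x, seq_len=2048)  # no AssertionError here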

lessw2020 commented 1 year ago

I think the core issue is the mixed-precision aspect of FSDP and how compile interacts with it. Per IBM, if you move your weights to pure BF16 (so no mixed precision) then this issue goes away (though they then report it errors out with a stride mismatch... but we'll get to that after this is resolved).
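
A minimal sketch of that workaround (hypothetical; names are illustrative):

    # Hypothetical workaround sketch: cast the whole model to bf16 and drop
    # the FSDP mixed-precision policy, so there are no fp32 master weights
    # left to disagree with the bf16 compute dtype.
    model = model.to(torch.bfloat16)
    model = FSDP(model)            # note: no mixed_precision policy
    model = torch.compile(model)   # dtype assert reportedly goes away
                                   # (a stride-mismatch error then surfaces instead)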

ezyang commented 1 year ago

Horace: There's another issue with DDP. Is there anyone signed up to own FSDP + torch.compile / DDP + torch.compile? This is basically @voznesenskym.

jon-chuang commented 10 months ago

> Is there anyone signed up to own FSDP + torch.compile / DDP + torch.compile?

I'm currently working on this.

voznesenskym commented 10 months ago

I am working on compile + FSDP. Tracked with Meta-internal posts.

jon-chuang commented 10 months ago

Sounds good @voznesenskym, we can sync on this later as needed.

yf225 commented 6 months ago

We are working on compile + FSDP, which is preferred over graph-break FSDP. We aim to have it ready at the "prototype" release stage by the end of H1.

shawndx commented 2 months ago

Ran into the same issue. Will this be fixed in an upcoming release? And is there any known workaround? Thanks.