pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

torchao.float8 + torch.compile does not work on HuggingFace's Mixtral model #1200

Open vkuzo opened 1 month ago

vkuzo commented 1 month ago

🐛 Describe the bug

Specifically, if we try to compile a float8 version of an FFN expert (MixtralBlockSparseTop2MLP), we see shape errors.

Script (requires torchao and transformers):

from transformers.models.mixtral.modeling_mixtral import (
    MixtralConfig,
    MixtralSparseMoeBlock,
    MixtralBlockSparseTop2MLP,
)

import torch

from torchao.float8 import convert_to_float8_training

config = MixtralConfig()
m = MixtralSparseMoeBlock(config).cuda()

def module_filter_fn(mod, fqn):
    # gate has out_channels == 8 which is not compatible with float8 gemms,
    # filter it out from float8 for now
    return "gate" not in fqn

convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# if we compile the whole module with float8, see this:
# https://gist.github.com/vkuzo/bc001fc1609bac361e6e13995c23d808
# m = torch.compile(m)

# try compiling individual experts
for name, child in m.experts.named_children():
    if isinstance(child, MixtralBlockSparseTop2MLP):
        new_child = torch.compile(child)
        setattr(m.experts, name, new_child)

# we should see each individual expert being compiled here
print(m)

# if torch.compile is on, this will fail with
# https://gist.github.com/vkuzo/b5136f21302cd2a259cbb37cda1aa717
for _ in range(10):
    x = torch.randn(1, 4096, 4096, device='cuda').requires_grad_()
    final_hidden_states, router_logits = m(x)
    # res = final_hidden_states.sum() + router_logits.sum()
    # res.backward()

print('done')

Output:

...
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1254, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1796, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1355, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2155, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_meta_registrations.py", line 5558, in meta_scaled_mm
    torch._check(
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/__init__.py", line 1573, in _check     
    _check_with(RuntimeError, cond, message)
  File "/home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/__init__.py", line 1555, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Expected self.size(1) to be divisible by 16, but got self.size(1)=1007

Full output: https://gist.github.com/vkuzo/b5136f21302cd2a259cbb37cda1aa717
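The odd size in the error (1007) is consistent with a per-expert token count. The repro feeds 4096 tokens through Mixtral's default of 8 experts with 2 experts per token, so each expert receives 8192 / 8 = 1024 token slots on average, but the exact count depends on the router. The pure-Python sketch below assumes random Gaussian logits in place of the real (randomly initialized) gate and counts the tokens per expert; `route_top2` is a hypothetical helper, not part of transformers or torchao.

```python
import random
from collections import Counter


def route_top2(num_tokens: int, num_experts: int, seed: int = 0) -> list:
    """Count how many tokens each expert receives under top-2 routing.

    Hypothetical helper: random Gaussian logits stand in for the router output.
    """
    rng = random.Random(seed)
    counts = Counter()
    for _ in range(num_tokens):
        logits = [rng.gauss(0.0, 1.0) for _ in range(num_experts)]
        # Indices of the two largest logits, i.e. the experts this token is sent to.
        top2 = sorted(range(num_experts), key=logits.__getitem__, reverse=True)[:2]
        counts.update(top2)
    return [counts[e] for e in range(num_experts)]


# 4096 tokens (the repro's input of shape (1, 4096, 4096)), 8 experts, top-2.
counts = route_top2(num_tokens=4096, num_experts=8)
print(counts, sum(counts))            # per-expert counts always sum to 2 * 4096
print([c % 16 == 0 for c in counts])  # whether each count passes the divisible-by-16 check
```

Because the counts change from batch to batch, any gemm whose shape depends on them can hit the divisibility check unless it is padded.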

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241023+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.34

Python version: 3.11.0 (main, Mar  1 2023, 18:26:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk14_zion_2601_gcd42476b84e9-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
GPU 4: NVIDIA H100
GPU 5: NVIDIA H100
GPU 6: NVIDIA H100
GPU 7: NVIDIA H100

Nvidia driver version: 535.154.05
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.8.0
/usr/lib64/libcudnn_adv_infer.so.8.8.0
/usr/lib64/libcudnn_adv_train.so.8.8.0
/usr/lib64/libcudnn_cnn_infer.so.8.8.0
/usr/lib64/libcudnn_cnn_train.so.8.8.0
/usr/lib64/libcudnn_ops_infer.so.8.8.0
/usr/lib64/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             384
On-line CPU(s) list:                0-383
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 9654 96-Core Processor
CPU family:                         25
Model:                              17
Thread(s) per core:                 2
Core(s) per socket:                 96
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU(s) scaling MHz:                 79%
CPU max MHz:                        3707.8120
CPU min MHz:                        1500.0000
BogoMIPS:                           4792.82
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                     AMD-V
L1d cache:                          6 MiB (192 instances)
L1i cache:                          6 MiB (192 instances)
L2 cache:                           192 MiB (192 instances)
L3 cache:                           768 MiB (24 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-95,192-287
NUMA node1 CPU(s):                  96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Vulnerable: eIBRS with unprivileged eBPF
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pytorch-lightning==2.4.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241023+cu121
[pip3] torchao==0.7.0+gitd2526126
[pip3] torchdata==0.8.0
[pip3] torchmetrics==1.5.0
[pip3] torchvision==0.20.0.dev20241023+cu121
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] pytorch-lightning         2.4.0                    pypi_0    pypi
[conda] pytorch-triton            3.1.0+cf34004b8a          pypi_0    pypi
[conda] torch                     2.6.0.dev20241023+cu121          pypi_0    pypi
[conda] torchao                   0.7.0+gitd2526126           dev_0    <develop>
[conda] torchdata                 0.8.0                    pypi_0    pypi
[conda] torchmetrics              1.5.0                    pypi_0    pypi
[conda] torchvision               0.20.0.dev20241023+cu121          pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi

cc @ezyang @chauhang @penguinwu

vkuzo commented 1 month ago

Update:

  1. The MoE model sends tensors with varying dims to the individual experts, so the failure is expected in the absence of padding.

  2. We can enable padding with the following update to the float8 config:

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig(pad_inner_dim=True)
convert_to_float8_training(m, module_filter_fn=module_filter_fn, config=config)

Full example: https://gist.github.com/vkuzo/f5cd488ab635ea3dfe1205aa68eca473

  3. Unfortunately, the above script does not work with triton 3.1.0 due to a Triton compilation error when padding f8e5m3. This is fixed in https://github.com/triton-lang/triton/pull/4222, which is not included in triton 3.1.0. We can install the Triton nightly to get around this error:
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
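The reason padding the inner dimension is safe can be checked with plain arithmetic: appending zero columns to A and the same number of zero rows to B leaves A @ B unchanged, so the inner dim can be rounded up to the next multiple of 16 before the float8 gemm. This is a pure-Python illustration of that identity only; `round_up`, `pad_inner` and `matmul` are hypothetical helpers and are not how torchao implements `pad_inner_dim`.

```python
def round_up(n: int, multiple: int = 16) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def matmul(a, b):
    """Naive matmul of a (M x K) and b (K x N), both lists of lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def pad_inner(a, b, multiple: int = 16):
    """Zero-pad the shared inner dim K of a (M x K) and b (K x N) up to a multiple."""
    k = len(b)
    extra = round_up(k, multiple) - k
    a_padded = [row + [0] * extra for row in a]               # extra zero columns in A
    b_padded = b + [[0] * len(b[0]) for _ in range(extra)]    # extra zero rows in B
    return a_padded, b_padded


# An inner dim of 1007, as in the error message, rounds up to 1008.
print(round_up(1007))  # → 1008

# Small example: K = 5 padded to 8 (multiple=4 keeps the matrices readable).
a = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
b = [[1, 0, 2], [0, 1, 0], [3, 0, 1], [1, 1, 1], [0, 2, 0]]
a_p, b_p = pad_inner(a, b, multiple=4)
print(len(a_p[0]), len(b_p))               # padded inner dims on both sides
print(matmul(a_p, b_p) == matmul(a, b))    # zero padding leaves the product unchanged
```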

After installing the Triton nightly and enabling float8 padding, we can compile this example end-to-end.

Keeping this issue open to track documenting this in our README.md file.

vkuzo commented 1 month ago

We should also add performance benchmarks for float8 + compile on MoE, as padding adds overhead.
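Before benchmarking, a rough upper bound on the extra gemm work can be had by counting padded elements along the padded dimension. This sketch assumes each expert's token count is rounded up to a multiple of 16 and ignores the memory traffic of materializing padded copies and any extra kernels, so it does not replace a real benchmark; `padding_overhead` is a hypothetical helper.

```python
def round_up(n: int, multiple: int = 16) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def padding_overhead(token_counts, multiple: int = 16) -> float:
    """Fraction of extra elements added by padding each expert's token count."""
    real = sum(token_counts)
    padded = sum(round_up(c, multiple) for c in token_counts)
    return (padded - real) / real


print(padding_overhead([1007]))        # large batches: padding adds about 0.1%
print(padding_overhead([1024, 1024]))  # already aligned: no padding
print(padding_overhead([1, 17, 33]))   # tiny per-expert batches: overhead dominates
```

The element count suggests the padded flops matter mostly when experts receive very few tokens; the copy cost is the part a benchmark would need to measure.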

qingquansong commented 1 month ago

Thank you very much, @vkuzo! It would be awesome to have a benchmarking comparison on this. And just to confirm, is this for the FSDP1 or FSDP2 version?

vkuzo commented 1 month ago

I wrote this issue to be independent of the choice of FSDP version, and the localized repro does not use FSDP at all, so I really hope that it's orthogonal :) If there are any issues when composing this with FSDP1 or FSDP2, I'm happy to help figure those out.

vkuzo commented 1 month ago

An additional thing: in my original repro for this issue, things worked in eager mode but not under compile, for the forward pass only. For fwd+bwd, both eager and compile were broken. I suspect that the meta function of torch._scaled_mm might have shape constraints that are more stringent than the actual implementation of the scaled gemm, which would explain eager working but compile failing for the forward. Noting it here so I don't forget; we should ensure eager matches compile here. TODO: file an issue / fix this in pytorch/pytorch.

bdhirsh commented 1 month ago

Oh @vkuzo, this errors in eager mode for me too, if I just uncomment the lines that run the backward:

res = final_hidden_states.sum() + router_logits.sum()
res.backward()

What's going on here is that:

(1) in eager mode, you won't actually see the shape error until you execute your backward

(2) under compile, we generate a backward graph ahead of time when we see that your forward outputs require gradients

(3) that means that instead of getting a delayed error at the time you run the backward, you'll get the error earlier (at compile time) when we trace out the backward graph

If you really want to run just the forward and not the backward under compile, you can run under no_grad (which will be more efficient anyway, as we don't need to save activations in the generated compiled graph).

So it seems like padding might be required here for both eager and compile?
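One way to see why the failure only appears once the backward is involved: for a linear layer y = x @ W.T, the token count is the inner (K) dimension only in the weight-gradient gemm. The sketch below checks just the inner-dim constraint quoted in the error message (divisible by 16); it does not model the other shape checks in torch._scaled_mm's meta function. It uses the w1 projection of a default Mixtral expert (4096 -> 14336) and an expert that received 1007 tokens; `gemm_inner_dims` and `failing_gemms` are hypothetical helpers.

```python
def gemm_inner_dims(num_tokens: int, in_features: int, out_features: int) -> dict:
    """Inner (K) dim of each gemm for y = x @ W.T,
    with x: (num_tokens, in_features) and W: (out_features, in_features)."""
    return {
        "forward: y = x @ W.T": in_features,
        "backward: grad_x = grad_y @ W": out_features,
        "backward: grad_W = grad_y.T @ x": num_tokens,
    }


def failing_gemms(num_tokens: int, in_features: int, out_features: int, multiple: int = 16) -> list:
    """Names of the gemms whose inner dim is not a multiple of `multiple`."""
    dims = gemm_inner_dims(num_tokens, in_features, out_features)
    return [name for name, k in dims.items() if k % multiple != 0]


print(failing_gemms(1007, 4096, 14336))  # → ['backward: grad_W = grad_y.T @ x']
print(failing_gemms(1024, 4096, 14336))  # → []
```

Under this simplified model the forward gemms never see the token count as K, which matches the forward-only eager run succeeding and the error surfacing when the backward is executed or traced.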

vkuzo commented 1 month ago

I see, thanks for the explanation @bdhirsh !