Closed: gradjitta closed this issue 2 months ago.
What is your GPU? Does installing torchao from source help?
I shared my env above; it's an H100 80GB HBM3, and I also built torchao from source.
Also, it looks like I don't see any error when I do the following:

```python
quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
```
Edit: it seems to work with this. Compilation takes around 14 min 20 s, and after that I get 9.98 it/s, up from 4.16 it/s.
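For anyone landing here, this is roughly the full snippet I ended up with. It is a sketch: the checkpoint id is the one from the benchmark table below, and the prompt/warmup loop is just illustrative.

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Quantize the transformer in place first, then compile the quantized module.
quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)

# Warmup: the first calls trigger compilation (~14 min 20 s in my run).
for _ in range(3):
    _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)
```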
Nice.
For autoquant, could you maybe try using the benchmark_image.py script?
Oh, with the benchmark script it goes through. I am using this:

```shell
python3 benchmark_image.py --compile --quantization autoquant --batch_size 1
```
Maybe the difference is the class, FluxPipeline vs. DiffusionPipeline? (EDIT: I can rerun the benchmark script with FluxPipeline and autoquant.)
| ckpt_id | batch_size | fuse | compile | compile_vae | quantization | sparsify | model_memory | inference_memory | time |
|---|---|---|---|---|---|---|---|---|---|
| black-forest-labs/FLUX.1-dev | 1 | False | True | False | autoquant | False | 31.438 | 32.461 | 3.407 |
This order results in an error:
```python
pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")
pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
```
whereas this doesn't:
```python
pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
```
Interesting. But for DiffusionPipeline this doesn’t happen?
It's the same for both DiffusionPipeline and FluxPipeline. At this point, I think it's the order in which torch.compile is called.
Oh yes, order matters a lot. For autoquant and quantize_() it's different. We should follow the order from benchmark_image.py.
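For reference, here is the working order end to end, assembled from the snippets above into a runnable sketch (the checkpoint id and warmup prompt are illustrative):

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import autoquant

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.transformer.to(memory_format=torch.channels_last)

# Compile first, then wrap with autoquant -- the reverse order errors out.
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)
pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)

# Warmup runs let autoquant observe real shapes before picking kernels.
for _ in range(3):
    _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)
```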
Closing this issue then.
I have also made it clear what the order of autoquant and torch.compile() should be: https://github.com/sayakpaul/diffusers-torchao/commit/ec305cf7c37b7eb52a922cbba9336baf881588ab
Hi @sayakpaul! Thanks for all your work related to Flux and diffusers. I came across your post on X about faster Flux. I have been following the snippet you shared there, but I ran into an error during compile with torchao. The warmup triggers the error below.
Compile error
```python
InternalTorchDynamoError: TypeError: _make_wrapper_subclass(): argument 'dtype' must be torch.dtype, not torch._C._TensorMeta

from user code:
  File "/home/mlops/flux-stuff/ao/torchao/quantization/autoquant.py", line 651, in autoquant_prehook
    real_model.forward(*args, **kwargs)
  File "/home/mlops/flux-stuff/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 442, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/mlops/miniconda3/envs/flux-fast/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

---------------------------------------------------------------------------
InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[7], line 3
      1 # warmup
      2 for _ in range(3):
----> 3     _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)
File ~/miniconda3/envs/flux-fast/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator...
```
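The traceback itself suggests a stopgap, quoted below as a sketch; note it only masks the failing frames by falling back to eager rather than fixing the root cause (which turned out to be call order, per the discussion above).

```python
import torch._dynamo

# Fallback suggested by the error message: suppress dynamo errors and run
# the failing frames in eager mode. Avoids the crash, loses the speedup.
torch._dynamo.config.suppress_errors = True
```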
And the following is my env:
```shell
Collecting environment information...
PyTorch version: 2.5.0.dev20240906
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 550.90.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 30
On-line CPU(s) list: 0-29
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8462Y+
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 30
Stepping: 8
BogoMIPS: 5600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 960 KiB (30 instances)
L1i cache: 960 KiB (30 instances)
L2 cache: 120 MiB (30 instances)
L3 cache: 480 MiB (30 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-29
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.5.0.dev20240906
[pip3] torchao==0.6.0+gitc6abf2bd
[pip3] torchaudio==2.5.0.dev20240907
[pip3] torchvision==0.20.0.dev20240907
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.10 py311h5eee18b_0
[conda] mkl_random 1.2.7 py311ha02d727_0
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 2.0.1 py311h08b1b3b_1
[conda] numpy-base 2.0.1 py311hf175353_1
[conda] pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.5.0.dev20240906 py3.11_cuda12.4_cudnn9.1.0_0 pytorch-nightly
[conda] pytorch-cuda 12.4 hc786d27_7 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchao 0.6.0+gitc6abf2bd dev_0
```

When I don't use autoquant, the compilation goes through.
Maybe I am missing something trivial, or could you point me to the correct torchao commit? Cheers!