sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

Using quantize_to_float8 results in "Only support dtype kwarg for autocast" #9

Closed: a-r-r-o-w closed this issue 3 months ago

a-r-r-o-w commented 3 months ago

It seems that the `device` argument cannot be passed to `pipe.to()`, whether as a positional or a keyword argument. Are there any workarounds, or am I doing something wrong?

Reproducer

```python
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

model_id = "THUDM/CogVideoX-2b"
device = "cuda"

# 1. Load models
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)

# 2. Quantize
transformer = quantize_to_float8(transformer, QuantConfig(ActivationCasting.DYNAMIC))

# 3. Load pipeline
pipe = CogVideoXPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.set_progress_bar_config(disable=True)
pipe.to(device=device)  # <--- Fails here

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(
    prompt=prompt,
    guidance_scale=6,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
)
export_to_video(video.frames[0], "output.mp4", fps=8)
```
Traceback

```
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/workflows/experiments/cogvideox-torchao/reproducer.py", line 35, in <module>
    pipe.to(device=device)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 435, in to
    module.to(device, dtype)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
    param_applied = fn(param)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1326, in convert
    return t.to(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/float8_tensor.py", line 359, in __torch_dispatch__
    return FLOAT8_OPS_TABLE[func](func, args, kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/float8_ops.py", line 244, in autocast_to_copy
    len(kwargs) == 1 and "dtype" in kwargs
AssertionError: Only support dtype kwarg for autocast
```

cc @sayakpaul @jerryzh168
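The assertion at the bottom of the traceback suggests that torchao's float8 handler for the copy op behind `.to()` only accepts a call whose sole keyword argument is `dtype`, so a `.to()` that also carries a device is rejected. The sketch below re-creates just that condition in plain Python, written from the assertion text alone. `is_supported_to_copy_kwargs` is a hypothetical name, strings stand in for real `torch.dtype`/`torch.device` objects, and the exact kwargs that `module.to(device, dtype)` forwards are an assumption.

```python
def is_supported_to_copy_kwargs(kwargs: dict) -> bool:
    """Hypothetical re-creation of the guard seen in torchao/float8/float8_ops.py.

    Mirrors the condition from the traceback:
        len(kwargs) == 1 and "dtype" in kwargs
    """
    return len(kwargs) == 1 and "dtype" in kwargs


# A dtype-only cast satisfies the guard.
print(is_supported_to_copy_kwargs({"dtype": "bfloat16"}))

# Assumption: moving to a device forwards a `device` kwarg as well,
# so the guard fails and the AssertionError above is raised.
print(is_supported_to_copy_kwargs({"dtype": "bfloat16", "device": "cuda"}))
print(is_supported_to_copy_kwargs({"device": "cuda"}))
```

Under this reading, a pure dtype cast of the float8 weights passes, while any device move does not.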

a-r-r-o-w commented 3 months ago

Surprisingly enough, moving the pipeline to CUDA first and then applying FP8 quantization fixes the issue above. However, I now get the following error:

Traceback

```
  File "/home/aryan/work/diffusers/workflows/experiments/cogvideox-torchao/compile.py", line 210, in <module>
    main(args.dtype, args.device, args.quantize_vae, args.compile, args.fuse_qkv)
  File "/home/aryan/work/diffusers/workflows/experiments/cogvideox-torchao/compile.py", line 150, in main
    video = run_inference(pipe)
  File "/home/aryan/work/diffusers/workflows/experiments/cogvideox-torchao/compile.py", line 124, in run_inference
    video = pipe(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 629, in __call__
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 297, in encode_prompt
    prompt_embeds = self._get_t5_prompt_embeds(
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 240, in _get_t5_prompt_embeds
    prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1971, in forward
    encoder_outputs = self.encoder(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1106, in forward
    layer_outputs = layer_module(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 686, in forward
    self_attention_outputs = self.layer[0](
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 593, in forward
    attention_output = self.SelfAttention(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 512, in forward
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/inference.py", line 112, in forward
    return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/float8_tensor.py", line 359, in __torch_dispatch__
    return FLOAT8_OPS_TABLE[func](func, args, kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/float8_ops.py", line 181, in float8_mm
    tensor_out = addmm_float8_unwrapped(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 54, in addmm_float8_unwrapped
    output = torch._scaled_mm(
RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+
```

This makes sense, because the machines I'm using have CUDA compute capability 8.0 or lower. @sayakpaul Would it be possible to run these benchmarks on a different machine, or should we just skip the ones that fail?
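If the failing benchmarks are to be skipped, one option is to gate the FP8 runs on the requirement stated in the error message (CUDA compute capability 8.9, or 9.0 and above). The sketch below is hypothetical: the helper names and the run/skip loop are illustrative, only the CUDA case is handled (the ROCm MI300+ path is ignored), and `torch.cuda.get_device_capability()` is used to read the active device's `(major, minor)` tuple.

```python
def supports_fp8_scaled_mm(capability: tuple) -> bool:
    """Return True if a CUDA (major, minor) compute capability meets the
    requirement quoted in the error: >= 9.0, or exactly 8.9.

    ROCm MI300+ devices are not covered by this check.
    """
    cap = tuple(capability)
    return cap >= (9, 0) or cap == (8, 9)


def current_device_supports_fp8() -> bool:
    """Query the active CUDA device; torch is imported lazily so the
    pure helper above can be used without it."""
    import torch

    if not torch.cuda.is_available():
        return False
    return supports_fp8_scaled_mm(torch.cuda.get_device_capability())


# Example capabilities: A100 is 8.0, A10 is 8.6, L4 is 8.9, H100 is 9.0.
for name, cap in [("A100", (8, 0)), ("A10", (8, 6)), ("L4", (8, 9)), ("H100", (9, 0))]:
    status = "run" if supports_fp8_scaled_mm(cap) else "skip"
    print(f"{name} {cap}: {status} FP8 benchmark")
```

Checking up front like this turns the `RuntimeError` deep inside `torch._scaled_mm` into an explicit skip before any model is quantized.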

sayakpaul commented 3 months ago

Hmm, an easier solution would be to test it on an H100.

a-r-r-o-w commented 3 months ago

Confirmed: the A100 does not support the FP8 quantization used here, because its CUDA compute capability (8.0) is lower than the minimum stated in the error. It works perfectly fine on an H100.
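For reference, the order of operations that worked in this thread (move the pipeline to CUDA first, then apply FP8 quantization, on an H100) can be wrapped in a small helper. This is a hypothetical sketch, not an API of diffusers or torchao: `pipe` is assumed to be a diffusers pipeline with a `transformer` attribute, and `quantize_fn` is whatever quantizer you pass in, for example the `quantize_to_float8` call from the reproducer.

```python
from typing import Any, Callable


def move_then_quantize(pipe: Any, device: str, quantize_fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: move a pipeline to `device`, then quantize its transformer.

    Moving first means `.to(device)` only ever sees the regular bf16 weights,
    avoiding the float8 `.to()` path that raised
    "Only support dtype kwarg for autocast".
    """
    # 1. Move every component (text encoder, transformer, VAE) while still unquantized.
    pipe.to(device)
    # 2. Quantize the transformer where it already lives.
    pipe.transformer = quantize_fn(pipe.transformer)
    return pipe


# Usage with the objects from the reproducer (needs an FP8-capable GPU such as an H100):
# from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
# pipe = move_then_quantize(
#     pipe, "cuda", lambda m: quantize_to_float8(m, QuantConfig(ActivationCasting.DYNAMIC))
# )
```

Passing the quantizer in as a callable keeps the ordering logic separate from any particular torchao configuration.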