sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

Try out FP6 #4

Closed: sayakpaul closed this issue 3 months ago.

sayakpaul commented 3 months ago

Error:

Traceback (most recent call last):
  File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 165, in <module>
    pipeline = load_pipeline(
  File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 81, in load_pipeline
    quantize_(pipeline.transformer, fp6_llm_weight_only())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 323, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 1 more time]
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 171, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/quantization/quant_api.py", line 262, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
AttributeError: 'Parameter' object has no attribute 'weight'

The PixArt transformer registers a standalone nn.Parameter (one that is not the weight of an nn.Linear): https://github.com/huggingface/diffusers/blob/a57a7af45cbef004c38e2a294a6457f7f3574e5d/src/diffusers/models/transformers/pixart_transformer_2d.py#L174
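To list parameters like this one, which live directly on a module rather than inside an nn.Linear, you can walk the module tree and check each module's own parameters. This is a minimal sketch using only PyTorch. `ToyBlock` is a hypothetical stand-in for a transformer that owns a bare nn.Parameter, and `find_non_linear_params` is a hypothetical helper, not part of diffusers or torchao:

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in: two Linear layers plus a bare nn.Parameter."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(dim, dim)
        # Registered directly on the module, not owned by any nn.Linear.
        self.extra_param = nn.Parameter(torch.randn(2, dim) / dim**0.5)


def find_non_linear_params(model: nn.Module) -> list:
    """Return fully qualified names of parameters whose owning module is not nn.Linear."""
    found = []
    for mod_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            continue
        # recurse=False: only parameters registered directly on this module.
        for param_name, _ in module.named_parameters(recurse=False):
            found.append(f"{mod_name}.{param_name}" if mod_name else param_name)
    return found


print(find_non_linear_params(ToyBlock()))
```

On a real pipeline, the same helper could be run on `pipeline.transformer` to see which parameters a Linear-only filter should leave alone.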

Is there a way to add filtering to fp6_llm_weight_only() so that it doesn't pick up nn.Parameters?

sayakpaul commented 3 months ago

@jerryzh168 could you advise?

msaroufim commented 3 months ago

quantize_ optionally takes a filter_fn argument: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L267
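A sketch of what such a filter could look like, assuming `filter_fn` is called as `filter_fn(module, fqn)` and a module is replaced only when it returns `True` (check the linked `quant_api.py` for the exact behavior in your installed version). `make_linear_filter` and `selected_fqns` are hypothetical helpers; `selected_fqns` only mimics a top-down traversal so the filter can be exercised without torchao:

```python
import torch.nn as nn


def make_linear_filter(skip_fqns=()):
    """Build a filter_fn(module, fqn) -> bool that accepts only nn.Linear
    modules whose fully qualified name is not in skip_fqns (hypothetical)."""
    skip = set(skip_fqns)

    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear) and fqn not in skip

    return filter_fn


def selected_fqns(model: nn.Module, filter_fn) -> list:
    """Walk the tree top-down; stop descending once a module matches."""
    hits = []

    def walk(module, fqn):
        if filter_fn(module, fqn):
            hits.append(fqn)
            return
        for name, child in module.named_children():
            walk(child, f"{fqn}.{name}" if fqn else name)

    walk(model, "")
    return hits


model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 4))
print(selected_fqns(model, make_linear_filter(skip_fqns={"2"})))
```

With torchao installed, the filter would be passed as `quantize_(pipeline.transformer, fp6_llm_weight_only(), filter_fn=make_linear_filter())`. Only modules are ever handed to `filter_fn`, so an `isinstance(module, nn.Linear)` check cannot match a bare nn.Parameter.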

jerryzh168 commented 3 months ago

This does look weird, because we already filter for nn.Linear by default. Let me try to reproduce it.

jerryzh168 commented 3 months ago

@sayakpaul `python3 benchmark_pixart.py --compile --quantization fp6` runs for me. Can you paste the command you used to reproduce it?

sayakpaul commented 3 months ago

Strange. `python3 benchmark_pixart.py --compile --quantization fp6` is failing for me with the same error. Let me try with the latest torchao.

sayakpaul commented 3 months ago

It works with the latest torchao.
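Since upgrading resolved the error, it can help to confirm which torchao version the active environment actually imports before re-running the benchmark (for example, after `pip install --upgrade torchao`). This is a small sketch using the standard library's `importlib.metadata`. `installed_version` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution's version string, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print("torchao:", installed_version("torchao") or "not installed")
```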