pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

autoquant api sharp edges #657

Open msaroufim opened 2 months ago

msaroufim commented 2 months ago

Context

I was trying to run the new Flux model but ran into some sharp edges with the autoquant API:

import time
import torchao
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torch import nn
from torch.utils.benchmark import Timer

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# Model CPU offload: each sub-model is moved to the GPU only while it runs
pipe.enable_model_cpu_offload()

def inference():
    prompt = "A cat holding a sign that says hello world"
    out = pipe(
        prompt=prompt,
        guidance_scale=0.,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    out.save("image.png")

tic = time.time()
inference()
toc = time.time()

print(f"Original Running time is {toc - tic}")

What I tried

The baseline was compiling the model, which made it about 25% faster:

pipe = torch.compile(pipe)
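Because torch.compile compiles lazily on the first call, a single timed call to a freshly compiled pipe mixes compilation time into the measurement, so compiled timings are usually taken after a warm-up call (the later repro in this thread does this by calling inference() once before timing). A minimal stdlib sketch of such a helper; `time_call` is a hypothetical name, not a torchao or diffusers API.

```python
import time
from statistics import mean
from typing import Callable


def time_call(fn: Callable[[], object], warmup: int = 1, iters: int = 3) -> float:
    """Mean wall-clock seconds per call of fn, measured after `warmup` untimed calls."""
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        fn()  # untimed: one-off costs such as torch.compile's first-call compilation land here
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        # For GPU work that does not end in a host copy, call torch.cuda.synchronize()
        # before reading the clock; inference() only returns after saving the image.
        samples.append(time.perf_counter() - start)
    return mean(samples)


# Usage with the inference() function from the script above (needs the pipeline and a GPU):
# print(f"Compiled running time is {time_call(inference)}")
```

Averaging a few timed calls after the warm-up also smooths out run-to-run noise, which matters when comparing speedups in the 25% range.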

So at first I tried wrapping a compiled pipe in autoquant, like this:

# Running over the pipe didn't work
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
#     for name, child in model.named_children():
# AttributeError: 'function' object has no attribute 'named_children'
# pipe = torch.autoquant(torch.compile(pipe, mode='max-autotune'))

The problem is that while torch.compile works with both nn.Modules and functions, torchao.autoquant works only with nn.Modules.
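The first traceback shows why: torchao's module-swapping code walks the model with `named_children()`, and a diffusers pipeline is not an nn.Module (torch.compile applied to it hands back a plain function, hence `'function' object has no attribute 'named_children'`). A minimal sketch of collecting the nn.Module components of a pipeline-like object so each can be quantized individually; `module_components` and `FakePipeline` are hypothetical names for illustration, not diffusers or torchao APIs.

```python
import torch
from torch import nn


def module_components(obj) -> dict[str, nn.Module]:
    """Return the nn.Module-valued attributes of a pipeline-like object.

    Hypothetical helper: APIs that traverse model.named_children() can only be
    applied to these, not to the container (or to a function wrapping it).
    """
    if isinstance(obj, nn.Module):
        return {"": obj}
    return {name: value for name, value in vars(obj).items() if isinstance(value, nn.Module)}


class FakePipeline:
    """Stand-in for a diffusion pipeline: callable, but not an nn.Module."""

    def __init__(self):
        self.transformer = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
        self.scheduler_config = {"steps": 4}  # non-module attribute, skipped

    def __call__(self, x):
        return self.transformer(x)


pipe = FakePipeline()
print(hasattr(pipe, "named_children"))  # False: module-walking APIs can't traverse it
print(sorted(module_components(pipe)))  # ['transformer']
```

On the real FluxPipeline the result would likely include the transformer, the VAE and the text encoders; targeting `pipe.transformer` directly, as tried next, is the narrow version of the same idea.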

So instead I tried applying it to the pipe's transformer only:

# Running over transformer only
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 783, in compute_should_use_set_data
#     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# TypeError: _has_compatible_shallow_copy_type(): argument 'from' (position 2) must be Tensor, not NoneType
# pipe.transformer = torchao.autoquant(torch.compile(pipe.transformer, mode="max-autotune"))

Another idea was to compile the pipe and autoquant only the transformer:

#   return t.type.__tensor_unflatten__(
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 177, in __tensor_unflatten__
#     return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 65, in __new__
#     return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
# torch._dynamo.exc.InternalTorchDynamoError: _make_wrapper_subclass(): argument 'dtype' must be torch.dtype, not torch._C._TensorMeta

pipe.transformer = torchao.autoquant(pipe.transformer)

Dependencies

pip freeze output:

accelerate==0.33.0
aiofiles==23.2.1
altair==5.4.0
annotated-types==0.7.0
anyio==4.4.0
attrs==24.2.0
blinker==1.8.2
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cycler==0.12.1
diffusers==0.30.0
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.112.0
ffmpy==0.4.0
filelock==3.13.1
fire==0.6.0
-e git+https://github.com/black-forest-labs/flux@c23ae247225daba30fbd56058d247cc1b1fc20a3#egg=flux
fonttools==4.53.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
gradio==4.41.0
gradio_client==1.3.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.24.5
idna==3.7
importlib_metadata==8.2.0
importlib_resources==6.4.0
invisible-watermark==0.2.0
Jinja2==3.1.4
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1.post1
mdurl==0.1.2
mpmath==1.3.0
narwhals==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
opencv-python==4.10.0.84
orjson==3.10.7
packaging==24.1
pandas==2.2.2
pillow==10.4.0
protobuf==5.27.3
psutil==6.0.0
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-triton==3.0.0+dedb7bdf33
pytz==2024.1
PyWavelets==1.6.0
PyYAML==6.0.2
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
ruff==0.5.7
safetensors==0.4.4
semantic-version==2.10.0
sentencepiece==0.2.0
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.37.2
streamlit==1.37.1
streamlit-keyup==0.2.4
sympy==1.13.1
tenacity==8.5.0
termcolor==2.4.0
tokenizers==0.19.1
toml==0.10.2
tomlkit==0.12.0
torch==2.5.0.dev20240811+cu121
torchao @ file:///home/marksaroufim/ao
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
transformers==4.44.0
triton==3.0.0
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
watchdog==4.0.2
websockets==12.0
zipp==3.20.0
HDCharles commented 2 months ago

Looked into this: the problem isn't autoquant. Seemingly the model can't be quantized at all; it has a very weird structure with pre-hooks that don't seem to play well with quantized tensor subclasses.

I can't actually figure out what the issue is.

It seems the pre-hook on pipe.transformer (e.g. hooks.py lines 161, 690 and 691 in the trace) has a .to call that causes some weirdness. Calling .to on the affine_quantized_tensor calls .to on the layout tensor and... eventually a PlainAQTLayout is created with no attributes. It happens somewhere in the trace between the

File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 299, in to self.layout_tensor.to(device),

and

File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 29, in realize self.vt = VariableBuilder(tx, self.source)(self.value)

lines. I set breakpoints to find the point where the change happens but was unable to: the first is a normal layout_tensor, the second is a layout_tensor with no attributes (self.value.__dict__ returns {}). When I apply the same function to that first result, I get a normal layout_tensor, so I'm not sure what the issue is.
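The accelerate/hooks.py and nn.Module frames in the traceback that follows show the pattern described above: model CPU offload runs a hook before each forward that calls module.to(self.execution_device), and Module.to reaches every parameter through _apply, so a quantized tensor subclass's .to override runs on every forward (and, under torch.compile, gets traced). A minimal CPU-only sketch of that pattern with a plain PyTorch forward pre-hook; it does not use accelerate, and `move_before_forward` is a hypothetical stand-in for the offload hook.

```python
import torch
from torch import nn

device = torch.device("cpu")  # accelerate would use the GPU execution device here
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
to_calls = []


def move_before_forward(module, args):
    # Mimics an offload pre-forward hook: move the whole module right before it runs.
    # Module.to() visits every parameter via _apply, so any tensor-subclass .to()
    # override on the weights executes here on every forward.
    to_calls.append(type(module).__name__)
    module.to(device)
    return None  # leave the positional inputs unchanged


handle = model.register_forward_pre_hook(move_before_forward)
x = torch.randn(3, 4)
for _ in range(3):
    y = model(x)
handle.remove()
print(len(to_calls), tuple(y.shape))  # 3 (3, 2)
```

With plain tensors the repeated .to is a cheap no-op; the point of the sketch is only that the weights' .to path is exercised once per forward, which is where the trace enters affine_quantized_tensor.py.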

Traceback (most recent call last):
  File "/home/cdhernandez/local/autoquant_flux/autoquant_flux.py", line 34, in <module>
    inference()
  File "/home/cdhernandez/local/autoquant_flux/autoquant_flux.py", line 15, in inference
    out = pipe(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/external_utils.py", line 39, in inner
    return fn(*args, **kwargs)
  File "/home/cdhernandez/local/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 696, in __call__
    noise_pred = self.transformer(
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 690, in pre_forward
    self.prev_module_hook.offload()
  File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 691, in torch_dynamo_resume_in_pre_forward_at_690
    module.to(self.execution_device)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 927, in _apply
    param_applied = fn(param)
  File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1326, in convert
    return t.to(
  File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 299, in to
    self.layout_tensor.to(device),
  File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 427, in to
    kwargs = self._get_to_kwargs(*args, **kwargs)
  File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 428, in torch_dynamo_resume_in_to_at_427
    return self.__class__(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 1238, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 1039, in __call__
    result = self._inner_convert(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 929, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 902, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 653, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/cdhernandez/local/pytorch/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/home/cdhernandez/local/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 208, in _fn
    return fn(*args, **kwargs)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2731, in run
    super().run()
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run
    while self.step():
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1715, in STORE_ATTR
    if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/base.py", line 110, in __instancecheck__
    instance = instance.realize()
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 63, in realize
    self._cache.realize()
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 29, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 337, in __call__
    vt = self._wrap(value)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 516, in _wrap
    return self.wrap_tensor(value)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 1435, in wrap_tensor
    self.assert_not_wrapped_by_this_graph(value)
  File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 1346, in assert_not_wrapped_by_this_graph
    if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
  File "/home/cdhernandez/local/pytorch/torch/_subclasses/fake_tensor.py", line 175, in is_fake
    attrs, _ = type(x).__tensor_flatten__(x)
  File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 416, in __tensor_flatten__
    return ["int_data", "scale", "zero_point"], [self.layout_type]
torch._dynamo.exc.InternalTorchDynamoError: 'PlainAQTLayout' object has no attribute 'layout_type'

from user code:
   File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 410, in __init__
    self.int_data = int_data

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
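One possible reading of the last frames, offered as an assumption rather than a confirmed diagnosis: the error is raised from __tensor_flatten__ while the "from user code" frame is still at the first statement of PlainAQTLayout.__init__ (`self.int_data = int_data`), i.e. before `layout_type` has been assigned. That would also match the object with an empty `__dict__` seen at the breakpoint. A pure-Python sketch of that failure mode; `Layout` and `inspect_half_built` are hypothetical stand-ins, not torchao code.

```python
class Layout:
    """Hypothetical stand-in for a layout object whose __init__ assigns fields in order."""

    def __init__(self, int_data, scale, zero_point, layout_type):
        self.int_data = int_data  # first assignment, like the frame in the traceback
        self.scale = scale
        self.zero_point = zero_point
        self.layout_type = layout_type  # assigned last

    def flatten(self):
        # Mirrors the shape of __tensor_flatten__: reads layout_type unconditionally.
        return ["int_data", "scale", "zero_point"], [self.layout_type]


def inspect_half_built():
    # __new__ without __init__ gives an object with an empty __dict__, the same
    # state a tracer would see if it inspected `self` at the first attribute store.
    obj = Layout.__new__(Layout)
    try:
        obj.flatten()
    except AttributeError as exc:
        return vars(obj), str(exc)
    return vars(obj), None


state, error = inspect_half_built()
print(state)  # {}
print(error)  # 'Layout' object has no attribute 'layout_type'
```

The `torch_dynamo_resume_in_to_at_427` frame indicates Dynamo hit a graph break inside `.to` and resumed tracing at the `self.__class__(...)` construction, which is where a half-built object could be observed.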

code:

import time
import torchao
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torch import nn
from torch.utils.benchmark import Timer

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

def inference():
    prompt = "A cat holding a sign that says hello world"
    out = pipe(
        prompt=prompt,
        guidance_scale=0.,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    out.save("image.png")

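with torch.no_grad():
    # Quantize the transformer's linear weights to int8, then compile the whole pipe
    quantize_(pipe.transformer, int8_weight_only())
    pipe = torch.compile(pipe)
    inference()  # warm-up call; the first call triggers compilation
    tic = time.time()
    inference()
    toc = time.time()
    print(f"Quantized + compiled running time is {toc - tic}")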

@jerryzh168 have you seen anything like this? It seems like something we should reasonably be able to handle; not sure what's going on.

jerryzh168 commented 3 weeks ago

Sorry, just saw this. This might be related to the PyTorch version; is this resolved now? I'll take a look a bit later.
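When re-checking a version-dependent report like this one (the original run used a torch 2.5.0 nightly and a local torchao checkout), it helps to record exactly what is installed. A small stdlib sketch using importlib.metadata; `package_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# The packages whose versions matter most for this issue
for name, ver in package_versions(["torch", "torchao", "diffusers", "accelerate"]).items():
    print(f"{name}: {ver}")
```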