Open msaroufim opened 2 months ago
Looked into this: the problem isn't autoquant. The model seemingly can't be quantized at all; it has an unusual structure with pre-forward hooks that don't play well with quantized tensor subclasses.
I can't figure out what the actual issue is, though.
It seems the pre-forward hook on pipe.transformer (hooks.py lines 161, 690 and 691 in the trace) calls a .to method that causes something odd to happen. Calling .to on the affine_quantized_tensor calls .to on its layout tensor, and eventually a PlainAQTLayout is created with no attributes. This happens somewhere in the trace between the
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 299, in to self.layout_tensor.to(device),
and
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 29, in realize self.vt = VariableBuilder(tx, self.source)(self.value)
lines. I set breakpoints to try to find the point where the change happens, but I was unable to. At the first line it is a normal layout_tensor; at the second it is a layout_tensor with no attributes (`self.value.__dict__` returns `{}`). When I apply the same function to that first result myself, I get a normal layout_tensor, so I'm not sure what the issue is.
Traceback (most recent call last):
File "/home/cdhernandez/local/autoquant_flux/autoquant_flux.py", line 34, in <module>
inference()
File "/home/cdhernandez/local/autoquant_flux/autoquant_flux.py", line 15, in inference
out = pipe(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/external_utils.py", line 39, in inner
return fn(*args, **kwargs)
File "/home/cdhernandez/local/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 696, in __call__
noise_pred = self.transformer(
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 161, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 690, in pre_forward
self.prev_module_hook.offload()
File "/home/cdhernandez/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py", line 691, in torch_dynamo_resume_in_pre_forward_at_690
module.to(self.execution_device)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1340, in to
return self._apply(convert)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
module._apply(fn)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
module._apply(fn)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 900, in _apply
module._apply(fn)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 927, in _apply
param_applied = fn(param)
File "/home/cdhernandez/local/pytorch/torch/nn/modules/module.py", line 1326, in convert
return t.to(
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 299, in to
self.layout_tensor.to(device),
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 427, in to
kwargs = self._get_to_kwargs(*args, **kwargs)
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 428, in torch_dynamo_resume_in_to_at_427
return self.__class__(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 1238, in __call__
return self._torchdynamo_orig_callable(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 1039, in __call__
result = self._inner_convert(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 514, in __call__
return _compile(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 929, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 902, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 653, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/cdhernandez/local/pytorch/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/home/cdhernandez/local/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 208, in _fn
return fn(*args, **kwargs)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/convert_frame.py", line 622, in transform
tracer.run()
File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2731, in run
super().run()
File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run
while self.step():
File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1715, in STORE_ATTR
if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/base.py", line 110, in __instancecheck__
instance = instance.realize()
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 63, in realize
self._cache.realize()
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/lazy.py", line 29, in realize
self.vt = VariableBuilder(tx, self.source)(self.value)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 337, in __call__
vt = self._wrap(value)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 516, in _wrap
return self.wrap_tensor(value)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 1435, in wrap_tensor
self.assert_not_wrapped_by_this_graph(value)
File "/home/cdhernandez/local/pytorch/torch/_dynamo/variables/builder.py", line 1346, in assert_not_wrapped_by_this_graph
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
File "/home/cdhernandez/local/pytorch/torch/_subclasses/fake_tensor.py", line 175, in is_fake
attrs, _ = type(x).__tensor_flatten__(x)
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 416, in __tensor_flatten__
return ["int_data", "scale", "zero_point"], [self.layout_type]
torch._dynamo.exc.InternalTorchDynamoError: 'PlainAQTLayout' object has no attribute 'layout_type'
from user code:
File "/home/cdhernandez/local/ao/torchao/dtypes/affine_quantized_tensor.py", line 410, in __init__
self.int_data = int_data
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
code:
import time

import torch
import torchao
from diffusers import FluxPipeline
from torch import nn
from torch.utils.benchmark import Timer
from torchao.quantization.quant_api import int8_weight_only, quantize_
# Load the FLUX.1-schnell pipeline with its weights in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): model CPU offload installs accelerate hooks whose pre_forward
# offloads the previous module and calls `module.to(execution_device)` before
# running a module (accelerate/hooks.py in the traceback). Under torch.compile,
# that `.to` on int8-quantized weights is where tracing fails; if the model fits
# in GPU memory, `pipe.to("cuda")` avoids these hooks -- TODO confirm.
pipe.enable_model_cpu_offload()
def inference(
    prompt="A cat holding a sign that says hello world",
    output_path="image.png",
):
    """Generate one image with the module-level ``pipe`` and save it.

    ``pipe`` is resolved as a global at call time, so once the script rebinds
    it with ``pipe = torch.compile(pipe)`` subsequent calls use the compiled
    pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        output_path: File path the first generated image is saved to.
    """
    out = pipe(
        prompt=prompt,
        guidance_scale=0.0,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    out.save(output_path)
# Quantize the transformer's weights with int8_weight_only, compile the whole
# pipeline, then time a single call after a warm-up call.
with torch.no_grad():
    quantize_(pipe.transformer, int8_weight_only())
    # Rebinds the global that inference() reads, so later calls are compiled.
    pipe = torch.compile(pipe)
    # Warm-up: torch.compile compiles lazily on the first call, so keep that
    # cost out of the measurement.
    inference()
    # perf_counter is monotonic and meant for measuring short intervals.
    # The timed call includes saving the image to disk.
    tic = time.perf_counter()
    inference()
    toc = time.perf_counter()
# NOTE(review): despite the "Original" label, this measures the int8-quantized,
# compiled pipeline.
print(f"Original Running time is {toc - tic}")
@jerryzh168 have you seen anything like this? It seems like something we should reasonably be able to handle; I'm not sure what's going on.
Sorry, just saw this. This might be related to the PyTorch version; is this resolved now? I'll take a look a bit later.
Context
I was trying to run the new Flux model but ran into some sharp bits with the autoquant API
What I tried
The baseline was compiling the model, which made it about 25% faster.
So at first I tried compiling autoquant like this:
And the problem was that while `torch.compile` works with both nn modules and functions, `torchao.autoquant` works only with nn modules. So instead I tried running it over only the pipe's transformer.
Another idea was to compile the pipe and autoquant the transformer.
Dependencies
pip freeze
accelerate==0.33.0 aiofiles==23.2.1 altair==5.4.0 annotated-types==0.7.0 anyio==4.4.0 attrs==24.2.0 blinker==1.8.2 cachetools==5.4.0 certifi==2024.7.4 charset-normalizer==3.3.2 click==8.1.7 contourpy==1.2.1 cycler==0.12.1 diffusers==0.30.0 einops==0.8.0 exceptiongroup==1.2.2 fastapi==0.112.0 ffmpy==0.4.0 filelock==3.13.1 fire==0.6.0 -e git+https://github.com/black-forest-labs/flux@c23ae247225daba30fbd56058d247cc1b1fc20a3#egg=flux fonttools==4.53.1 fsspec==2024.6.1 gitdb==4.0.11 GitPython==3.1.43 gradio==4.41.0 gradio_client==1.3.0 h11==0.14.0 httpcore==1.0.5 httpx==0.27.0 huggingface-hub==0.24.5 idna==3.7 importlib_metadata==8.2.0 importlib_resources==6.4.0 invisible-watermark==0.2.0 Jinja2==3.1.4 jsonschema==4.23.0 jsonschema-specifications==2023.12.1 kiwisolver==1.4.5 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.1.post1 mdurl==0.1.2 mpmath==1.3.0 narwhals==1.3.0 networkx==3.3 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.1.105 nvidia-nvtx-cu12==12.1.105 opencv-python==4.10.0.84 orjson==3.10.7 packaging==24.1 pandas==2.2.2 pillow==10.4.0 protobuf==5.27.3 psutil==6.0.0 pyarrow==17.0.0 pydantic==2.8.2 pydantic_core==2.20.1 pydeck==0.9.1 pydub==0.25.1 Pygments==2.18.0 pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-multipart==0.0.9 pytorch-triton==3.0.0+dedb7bdf33 pytz==2024.1 PyWavelets==1.6.0 PyYAML==6.0.2 referencing==0.35.1 regex==2024.7.24 requests==2.32.3 rich==13.7.1 rpds-py==0.20.0 ruff==0.5.7 safetensors==0.4.4 semantic-version==2.10.0 sentencepiece==0.2.0 shellingham==1.5.4 six==1.16.0 smmap==5.0.1 sniffio==1.3.1 starlette==0.37.2 streamlit==1.37.1 streamlit-keyup==0.2.4 sympy==1.13.1 tenacity==8.5.0 termcolor==2.4.0 tokenizers==0.19.1 
toml==0.10.2 tomlkit==0.12.0 torch==2.5.0.dev20240811+cu121 torchao @ file:///home/marksaroufim/ao torchvision==0.19.0 tornado==6.4.1 tqdm==4.66.5 transformers==4.44.0 triton==3.0.0 typer==0.12.3 typing_extensions==4.12.2 tzdata==2024.1 urllib3==2.2.2 uvicorn==0.30.5 watchdog==4.0.2 websockets==12.0 zipp==3.20.0