From Mark internally:

> Regarding the benchmarks: int4wo quantization is optimized for small batch sizes, so don't bother trying it for larger ones.
After disabling the compile flags:
ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8dq | False | 9.852 | 33.967 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | None | False | 10.214 | 3.858 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int4wo | False | 9.605 | 85.653 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8wo | False | 9.677 | 10.813 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | None | False | 10.215 | 8.439 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | None | False | 10.215 | 7.804 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | None | False | 10.212 | 4.285 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int4wo | False | 9.794 | 42.732 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8wo | False | 9.851 | 5.536 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | autoquant | False | 10.236 | 36.625 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | autoquant | False | 10.235 | 21.477 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.688 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.209 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8wo | False | 9.677 | 5.588 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.675 | 1.596 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int4wo | False | 9.604 | 42.92 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.813 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.243 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8dq | False | 9.677 | 21.853 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | autoquant | False | 9.673 | 4.198 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.826 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.537 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8dq | False | 9.851 | 19.047 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.214 | 1.187 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8dq | False | 9.677 | 36.674 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | autoquant | False | 9.674 | 8.135 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8wo | False | 9.851 | 10.799 |
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int4wo | False | 9.795 | 85.386 |
CSV: analysis_torchao_without_compile_flags.csv
Cc: @jerryzh168
Seems like without the compilation flag the numbers take a hit.
```python
old_df.query("batch_size==1").sort_values(by="time")
```
# | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---|---
0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | nan | False | 10.214 | 1.034 |
1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.182 |
2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | nan | False | 10.211 | 1.195 |
3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.545 |
4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.676 | 1.573 |
5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.833 |
6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.813 |
7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.604 | 10.894 |
8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.327 |
```python
df.query("batch_size==1").sort_values(by="time").reset_index(drop=True)
```
# | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---|---
0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.214 | 1.187 |
2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.209 |
3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.537 |
4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.675 | 1.596 |
5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.813 |
6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.826 |
7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.243 |
8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.688 |
For all this perf work we'll need to start looking at traces.
Would appreciate either of the two:
@jerryzh168 when I ran the benchmark with these changes https://github.com/sayakpaul/diffusers-torchao/commit/fa9154eeaf9e805f4bff8cfff209215ee425a163 I get:
# | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---|---
0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.211 | 1.183 |
2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.185 |
3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.538 |
4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.676 | 1.56 |
5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 10.747 |
6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.82 |
7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.889 |
8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.849 | 8.549 |
@sayakpaul we are mainly interested in compile flag == True, but it seems that only two of the tests have it enabled? Can you include the others as well?
Discovered another bug when running:

```bash
python benchmark_pixart.py --quantization=int8wo --compile
```

The (truncated) traceback ends in:

```
return func(f, types, args, kwargs)
```
What is the pytorch version you are using? If it's not nightly, then you may need to call torchao.utils.unwrap_tensor_subclass in order to make torch.compile work; otherwise it's not needed.
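For reference, a minimal sketch of that workaround on a stand-in model (not the benchmark script itself; `quantize_` with `int8_weight_only` as in torchao's README, and the unwrap call is only needed off nightly):

```python
# Hedged sketch: quantize, unwrap torchao's tensor subclasses (stable torch
# only), then compile. A toy linear stands in for pipeline.transformer.
import torch
import torch.nn as nn
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
quantize_(model, int8_weight_only())

# On non-nightly torch, flatten the subclasses so torch.compile can trace them;
# on a recent nightly this call is not needed.
model = unwrap_tensor_subclass(model)

model = torch.compile(model, mode="max-autotune", fullgraph=True)
_ = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```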
It's not nightly. I will try with nightly.
Nightly is working. Will run the entire suite again and post here.
@jerryzh168 @msaroufim I have much more promising results to share now:
# | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---|---
0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 10.203 | 0.907 |
1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.214 | 0.926 |
2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8wo | False | 9.676 | 0.981 |
3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8wo | False | 9.849 | 0.984 |
4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | fp6 | False | 10.034 | 0.995 |
5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | fp6 | False | 10.301 | 0.997 |
6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8dq | False | 9.849 | 1.015 |
7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8dq | False | 9.675 | 1.054 |
8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int4wo | False | 9.798 | 10.201 |
9 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int4wo | False | 9.609 | 10.213 |
This is with torch nightly and torchao installed from the latest commit. I played a bit too, inspired by https://github.com/sayakpaul/diffusers-torchao/commit/fa9154eeaf9e805f4bff8cfff209215ee425a163. LMK WYT.
Thanks @sayakpaul, this result is similar to what I get actually. int4wo seems to have exploded, and all the other quant methods run in similar time compared to fp16, but it's nice to see that autoquant can still get some speedup. I think we can investigate a bit why int4 weight-only quant does not work; a good starting point is to output the generated code and check if something is wrong, example: https://github.com/pytorch/ao/blob/d582f9aba0c1ba5275384d9e81999854b4d951e0/tutorials/quantize_vit/run.sh#L13 (basically set TORCH_LOGS='output_code' when running the benchmark). We can do the same for the other types of quantization as well.
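Concretely, that can look like this (using the same benchmark script as earlier in this thread):

```bash
# Dump Inductor's generated code for the int4wo + compile run; swap
# --quantization to inspect the other modes the same way.
TORCH_LOGS='output_code' python benchmark_pixart.py --quantization=int4wo --compile
```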
Another question I have is: is this model more compute bound or memory bound? I think weight-only quant or dynamic quant would help more with memory-bound use cases.
Note that I am on bf16 and not on fp16. I can get you the generated code for int4 first, but LMK if you need any others.

> is this model more compute bound or memory bound? I think weight-only quant or dynamic quant would help more with memory-bound use cases

It is more compute bound. I can use this code to get you a trace as well: https://github.com/huggingface/diffusion-fast/blob/main/run_profile.py

LMK if you have any other suggestions.
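For reference, a minimal sketch of what such a trace capture can look like with torch.profiler (this is not the linked run_profile.py script itself; prompt, step count, and output filename are placeholders):

```python
# Hedged sketch: capture a Chrome trace of one PixArt-Sigma call.
import torch
from diffusers import PixArtSigmaPipeline
from torch.profiler import ProfilerActivity, profile

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.bfloat16
).to("cuda")

# A short run keeps the trace small; prompt and step count are arbitrary.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    _ = pipeline("a photo of an astronaut", num_inference_steps=4)

prof.export_chrome_trace("pixart_bf16_trace.json")  # open in chrome://tracing or Perfetto
```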
Thanks @sayakpaul, a trace might be helpful as well. For compute-bound workloads I feel static quant would make more sense, but we can look at the trace to see where the time is spent (a trace for bfloat16, int8dq, and int4wo would be helpful, I think).
@sayakpaul maybe you can try out float8 inference static quant as well: https://github.com/pytorch/ao/blob/d582f9aba0c1ba5275384d9e81999854b4d951e0/torchao/float8/inference.py#L34
Dynamic FP8 (compile set to True):
batch_size | memory (GB) | time (s)
---|---|---
1 | 9.672 | 0.914
4 | 9.673 | 3.259
8 | 9.674 | 6.398
Seems like FP8 is better than vanilla BF16.
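For context, a hedged sketch of how the dynamic-FP8 numbers above can be produced, assuming the `quantize_to_float8`, `QuantConfig`, and `ActivationCasting` helpers from the `torchao.float8.inference` file Jerry linked (compile mode is an assumption):

```python
# Hedged sketch of dynamic float8 quantization of the transformer.
import torch
from diffusers import PixArtSigmaPipeline
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.bfloat16
).to("cuda")

# Dynamic activation casting needs no precomputed scale (unlike STATIC below).
pipeline.transformer = quantize_to_float8(
    pipeline.transformer, QuantConfig(ActivationCasting.DYNAMIC)
)
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)
```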
@jerryzh168 with static FP8, I am hitting the following:
```
Traceback (most recent call last):
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1437, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/__init__.py", line 2231, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1230, in compile_fx
return compile_fx(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1488, in compile_fx
return aot_autograd(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1019, in aot_module_simplified
compiled_fn = dispatch_and_compile()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1008, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 429, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 730, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 178, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1327, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1394, in _fw_compiler_base
return inner_compile(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 457, in compile_fx_inner
return _compile_fx_inner(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/debug.py", line 313, in inner
return fn(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 446, in wrapper
return fn(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 644, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 1293, in load
compiled_graph = compile_fx_fn(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 554, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 855, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1844, in compile_to_fn
return self.compile_to_module().call
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1770, in compile_to_module
return self._compile_to_module()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1776, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1717, in codegen
self.scheduler = Scheduler(self.operations)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 1553, in __init__
self._init(nodes)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 1623, in _init
self.nodes = self.fuse_nodes(self.nodes)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2017, in fuse_nodes
nodes = self.fuse_nodes_once(nodes)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2291, in fuse_nodes_once
if not self.speedup_by_fusion(node1, node2):
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2214, in speedup_by_fusion
ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2052, in benchmark_fused_nodes
return backend.benchmark_fused_nodes(nodes)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 92, in benchmark_fused_nodes
return self._triton_scheduling.benchmark_fused_nodes(nodes)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/triton.py", line 3028, in benchmark_fused_nodes
src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/simd.py", line 1805, in generate_kernel_code_from_nodes
src_code = self.codegen_template(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/simd.py", line 1489, in codegen_template
partial_code.finalize_hook("<DEF_KERNEL>")
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 97, in finalize_hook
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 306, in hook
code.splice(self.jit_lines())
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 235, in jit_lines
num_gb = self.estimate_kernel_num_bytes() / 1e9
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 206, in estimate_kernel_num_bytes
numel = functools.reduce(operator.mul, size)
TypeError: reduce() of empty sequence with no initial value
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 175, in <module>
info = run_benchmark(pipeline, args)
File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 113, in run_benchmark
run_inference(pipeline, batch_size=args.batch_size)
File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 94, in run_inference
_ = pipe(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/fsx/sayak/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 834, in __call__
noise_pred = self.transformer(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 464, in _fn
return fn(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1224, in __call__
return self._torchdynamo_orig_callable(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
return _compile(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 896, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 662, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 697, in _compile_inner
out_code = transform_code_object(code, transform)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 631, in transform
tracer.run()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2722, in run
super().run()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
while self.step():
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
self.dispatch_table[inst.opcode](self, inst)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2913, in RETURN_VALUE
self._return(inst)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2898, in _return
self.output.compile_subgraph(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1133, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1360, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1407, in call_user_compiler
return self._call_user_compiler(gm)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1456, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: reduce() of empty sequence with no initial value
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
(I am on torch nightly 2.5.0.dev20240806+cu121).
I had to make the following changes to make static quant work:
```python
# imports assumed from torchao.float8.inference (see the file linked above)
import torch
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

pipeline.transformer = quantize_to_float8(
    pipeline.transformer,
    QuantConfig(ActivationCasting.STATIC, static_quantization_scale=torch.tensor(0.5).to("cuda")),
)
```
Weird that with autotuned autoquant, we get worse results:

```json
{"batch_size": 1, "fuse": false, "compile": true, "quantization": "autoquant", "sparsify": false, "memory": "9.676", "time": "1.005"}
```
Here's the data.pkl: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/data.pkl
Experiment was launched with:
```bash
export TORCHAO_AUTOTUNER_ENABLE=1
python benchmark_pixart.py --compile --quantization=autoquant
```
With the latest nightlies in both torch and torchao:
ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---
PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 10.203 | 0.958 |
New results (cc: @jerryzh168):
# | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory (GB) | time (s)
---|---|---|---|---|---|---|---|---
0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.214 | 0.865 |
1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8wo | False | 9.847 | 0.897 |
2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | fp6 | False | 10.299 | 0.926 |
3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8wo | False | 9.672 | 0.93 |
4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | fp6 | False | 10.028 | 0.933 |
5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8dq | False | 9.846 | 0.939 |
6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8dq | False | 9.673 | 0.98 |
7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int4wo | False | 9.795 | 10.142 |
{"batch_size": 1, "fuse": false, "compile": true, "quantization": "autoquant", "sparsify": false, "memory": "9.676", "time": "1.005"}
Thanks for the update @sayakpaul. Maybe we can try larger models. Did you get a speedup from quanto with the current model, btw?
Could you help run SD3 as discussed just to gauge the ballpark?
Closing this issue as it has served its purpose.
Here's a CSV for easy analysis in pandas: analysis_torchao.csv
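A quick way to slice it (filename taken from the attachment above):

```python
# Minimal sketch: load the attached results CSV and reproduce the
# batch_size == 1 ranking used throughout this thread.
import pandas as pd

df = pd.read_csv("analysis_torchao.csv")
print(df.query("batch_size == 1").sort_values(by="time").reset_index(drop=True))
```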