sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

Initial results #2

Closed sayakpaul closed 2 months ago

sayakpaul commented 3 months ago
| ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8dq | False | 9.853 | 34.028 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | None | False | 10.212 | 3.893 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int4wo | False | 9.608 | 85.639 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8wo | False | 9.677 | 10.854 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | None | False | 10.213 | 8.44 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | None | False | 10.213 | 7.68 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | None | False | 10.212 | 4.296 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int4wo | False | 9.795 | 42.739 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8wo | False | 9.851 | 5.544 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | autoquant | False | 10.236 | 36.919 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | autoquant | False | 10.235 | 21.792 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.833 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.182 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8wo | False | 9.677 | 5.585 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.676 | 1.573 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int4wo | False | 9.603 | 42.934 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.604 | 10.894 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.327 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8dq | False | 9.677 | 21.834 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.214 | 1.034 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | autoquant | False | 9.673 | 4.18 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.813 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.545 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8dq | False | 9.852 | 19.187 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.211 | 1.195 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8dq | False | 9.678 | 36.918 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | autoquant | False | 9.674 | 8.132 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8wo | False | 9.852 | 10.818 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int4wo | False | 9.787 | 85.226 |

Here's a CSV for easy analysis in pandas: analysis_torchao.csv
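As a rough illustration of the kind of analysis the CSV allows, the sketch below expresses each configuration's time as a multiple of the uncompiled, unquantized baseline at batch size 1. It uses only the standard library and a handful of rows copied from the table above rather than the attached file; the comma-separated layout and the `relative_times` helper are assumptions made for the example.

```python
import csv
import io

# Subset of the batch_size == 1 rows from the table above (illustration only).
ROWS = """\
batch_size,fuse,compile,quantization,memory,time
1,False,False,None,10.211,1.195
1,False,True,None,10.214,1.034
1,False,True,autoquant,9.672,1.182
1,False,False,int8wo,9.676,1.573
1,False,False,int4wo,9.604,10.894
1,False,False,int8dq,9.676,11.327
"""


def relative_times(csv_text):
    """Return (label, time / baseline_time) pairs, fastest first."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    for row in rows:
        row["time"] = float(row["time"])  # parse before comparing or sorting
    # Baseline: no compile, no quantization.
    baseline = next(
        r["time"]
        for r in rows
        if r["compile"] == "False" and r["quantization"] == "None"
    )
    ordered = sorted(rows, key=lambda r: r["time"])
    return [
        (f"{r['quantization']}, compile={r['compile']}", r["time"] / baseline)
        for r in ordered
    ]


for label, ratio in relative_times(ROWS):
    print(f"{label:<28} {ratio:6.2f}x baseline time")
```

A ratio below 1 means the configuration was faster than the baseline; the same logic carries over to pandas with a `groupby` on `batch_size`.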

sayakpaul commented 3 months ago

From Mark internally:

Regarding the benchmarks: int4wo quantization is optimized for small batch sizes, so don't bother trying it for larger ones.
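One way to read this against the numbers above is to divide `time` by `batch_size` to get a per-image cost. The sketch below does that arithmetic for the uncompiled, unfused rows of the first table; it only restates the measurements and makes no claim about why int4wo is slow in these runs.

```python
# time values copied from the first table (fuse=False, compile=False).
TIMES = {
    "None": {1: 1.195, 4: 4.296, 8: 8.44},
    "int8wo": {1: 1.573, 4: 5.585, 8: 10.854},
    "int4wo": {1: 10.894, 4: 42.934, 8: 85.639},
}

# Per-image cost: total time divided by the number of images in the batch.
per_image = {
    quant: {bs: t / bs for bs, t in by_batch.items()}
    for quant, by_batch in TIMES.items()
}

for quant, by_batch in per_image.items():
    cells = ", ".join(f"bs={bs}: {t:.3f}" for bs, t in by_batch.items())
    print(f"{quant:<7} {cells}")
```

If the per-image cost stays flat as the batch grows, batching is not amortizing anything for that configuration.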

sayakpaul commented 3 months ago

After disabling the compile flags:

| ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8dq | False | 9.852 | 33.967 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | None | False | 10.214 | 3.858 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int4wo | False | 9.605 | 85.653 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8wo | False | 9.677 | 10.813 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | None | False | 10.215 | 8.439 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | None | False | 10.215 | 7.804 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | None | False | 10.212 | 4.285 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int4wo | False | 9.794 | 42.732 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8wo | False | 9.851 | 5.536 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | autoquant | False | 10.236 | 36.625 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | autoquant | False | 10.235 | 21.477 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.688 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.209 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8wo | False | 9.677 | 5.588 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.675 | 1.596 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int4wo | False | 9.604 | 42.92 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.813 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.243 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | False | int8dq | False | 9.677 | 21.853 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | False | True | autoquant | False | 9.673 | 4.198 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.826 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.537 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 4 | True | False | int8dq | False | 9.851 | 19.047 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.214 | 1.187 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | False | int8dq | False | 9.677 | 36.674 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | False | True | autoquant | False | 9.674 | 8.135 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int8wo | False | 9.851 | 10.799 |
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 8 | True | False | int4wo | False | 9.795 | 85.386 |

CSV: analysis_torchao_without_compile_flags.csv

Cc: @jerryzh168

sayakpaul commented 3 months ago

Seems like without the compilation flag the numbers take a hit.

old_df.query("batch_size==1").sort_values(by="time")

| | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|---|
| 0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | nan | False | 10.214 | 1.034 |
| 1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.182 |
| 2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | nan | False | 10.211 | 1.195 |
| 3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.545 |
| 4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.676 | 1.573 |
| 5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.833 |
| 6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.813 |
| 7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.604 | 10.894 |
| 8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.327 |

df.query("batch_size==1").sort_values(by="time").reset_index(drop=True)

| | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|---|
| 0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
| 1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.214 | 1.187 |
| 2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.209 |
| 3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.537 |
| 4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.675 | 1.596 |
| 5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.813 |
| 6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.826 |
| 7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 11.243 |
| 8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.85 | 8.688 |

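
In the second listing the 8.688 row lands after 11.243, which is the order a lexicographic (string) sort of the `time` column would produce. That the column was read as strings in `df` is an assumption, not something confirmed in this thread; the stdlib sketch below only reproduces the effect and shows that converting to `float` before sorting restores the numeric order (in pandas, `pd.to_numeric` plays the same role).

```python
# time values from the second listing, deliberately kept as strings.
times = ["8.688", "1.052", "11.243", "1.187", "10.826", "1.209", "10.813", "1.537", "1.596"]

# Sorting strings compares character by character, so "10.813" < "8.688".
as_strings = sorted(times)

# Sorting with a float key compares the numeric values instead.
as_numbers = sorted(times, key=float)

print("string sort :", as_strings)
print("numeric sort:", as_numbers)
```
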
msaroufim commented 3 months ago

For all this perf work we'll need to start looking at traces

sayakpaul commented 3 months ago

Would appreciate any of the two:

sayakpaul commented 3 months ago

@jerryzh168 when I ran the benchmark with these changes https://github.com/sayakpaul/diffusers-torchao/commit/fa9154eeaf9e805f4bff8cfff209215ee425a163 I get:

| | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|---|
| 0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.211 | 1.052 |
| 1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | None | False | 10.211 | 1.183 |
| 2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 9.672 | 1.185 |
| 3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8wo | False | 9.85 | 1.538 |
| 4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8wo | False | 9.676 | 1.56 |
| 5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int8dq | False | 9.676 | 10.747 |
| 6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int4wo | False | 9.793 | 10.82 |
| 7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | False | int4wo | False | 9.603 | 10.889 |
| 8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | False | int8dq | False | 9.849 | 8.549 |

jerryzh168 commented 3 months ago

@sayakpaul we are mainly interested in compile flag == True, but it seems that only two of the tests have it set. Can you include the others as well?

sayakpaul commented 3 months ago

Discovered another bug when running:

python benchmark_pixart.py --quantization=int8wo --compile
Error ```bash Traceback (most recent call last): File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 171, in info = run_benchmark(pipeline, args) File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 110, in run_benchmark run_inference(pipeline, batch_size=args.batch_size) File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 91, in run_inference _ = pipe( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context return func(*args, **kwargs) File "/fsx/sayak/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 834, in __call__ noise_pred = self.transformer( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn return fn(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__ return self._torchdynamo_orig_callable( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__ return _compile( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_utils_internal.py", line 84, in wrapper_function return 
StrobelightCompileTimeProfiler.profile_compile_time( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time return func(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner return func(*args, **kwds) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper r = func(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner out_code = transform_code_object(code, transform) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object transformations(instructions, code_options) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn return fn(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run super().run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1512, in CALL_FUNCTION_KW self.call_function(fn, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function return tx.inline_user_function_return( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function return super().call_function(tx, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function return super().call_function(tx, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function return tx.inline_user_function_return( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function return super().call_function(tx, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function return super().call_function(tx, args, kwargs) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION self.call_function(fn, args, {}) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function return tx.inline_user_function_return( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function return super().call_function(tx, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function return super().call_function(tx, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_ tracer.run() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run while self.step(): File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step self.dispatch_table[inst.opcode](self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper return inner_fn(self, inst) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION self.call_function(fn, args, {}) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function self.push(fn.call_function(self, args, kwargs)) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/nn_module.py", line 409, in call_function return wrap_fx_proxy( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None File 
"/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value ret_val = wrap_fake_exception( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception return fn() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1786, in lambda: run_node(tx.output, node, args, kwargs, nnmodule) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1921, in run_node raise RuntimeError(make_error_message(e)).with_traceback( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1908, in run_node return nnmodule(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 117, in forward return F.linear(input, self.weight, self.bias) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/dtypes/utils.py", line 53, in _dispatch__torch_function__ return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/dtypes/utils.py", line 36, in wrapper return func(f, types, args, kwargs) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 840, in _ weight_tensor = weight_tensor.dequantize() File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 155, in dequantize int_data, scale, zero_point = 
self.layout_tensor.get_plain() torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___adaln_single_emb_timestep_embedder_linear_1(*(FakeTensor(..., device='cuda:0', size=(2, 256), dtype=torch.bfloat16),), **{}): 'FakeTensor' object has no attribute 'get_plain' from user code: File "/fsx/sayak/diffusers/src/diffusers/models/transformers/pixart_transformer_2d.py", line 371, in forward timestep, embedded_timestep = self.adaln_single( File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 201, in forward embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/diffusers/src/diffusers/models/embeddings.py", line 1222, in forward timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) File "/fsx/sayak/diffusers/src/diffusers/models/embeddings.py", line 492, in forward sample = self.linear_1(sample) Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information You can suppress this exception and fall back to eager by setting: import torch._dynamo torch._dynamo.config.suppress_errors = True ```
jerryzh168 commented 3 months ago
What PyTorch version are you using? If it's not nightly, you may need to call torchao.utils.unwrap_tensor_subclass to make torch.compile work; on nightly it's not needed.
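A minimal sketch of that conditional, under stated assumptions: `looks_like_nightly` and `prepare_for_compile` are hypothetical helpers, and treating a `.dev` segment in the version string (as in `2.5.0.dev20240709+cu121`) as the marker of a nightly build is a heuristic, not something the comment above specifies. The unwrap function is passed in so the example runs without torch installed.

```python
def looks_like_nightly(torch_version):
    """Heuristic: nightly wheels carry a '.dev<date>' segment in their version."""
    return ".dev" in torch_version


def prepare_for_compile(model, torch_version, unwrap_fn):
    """Apply unwrap_fn to the model only when the build is not a nightly."""
    if not looks_like_nightly(torch_version):
        unwrap_fn(model)  # expected to modify the model in place
    return model


# Stand-ins so the sketch runs on its own; see the note below for real usage.
calls = []
dummy_model = object()
prepare_for_compile(dummy_model, "2.4.0+cu121", calls.append)
prepare_for_compile(dummy_model, "2.5.0.dev20240709+cu121", calls.append)
print(f"unwrap called {len(calls)} time(s)")
```

In an actual script, `torch_version` would be `torch.__version__` and `unwrap_fn` would be `torchao.utils.unwrap_tensor_subclass`, called on the quantized module before `torch.compile`.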

sayakpaul commented 3 months ago

It's not nightly. I will try with nightly.

sayakpaul commented 3 months ago

Nightly is working. Will run the entire suite again and post here.

sayakpaul commented 3 months ago

@jerryzh168 @msaroufim I have much more promising results to share now:

| | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|---|
| 0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 10.203 | 0.907 |
| 1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.214 | 0.926 |
| 2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8wo | False | 9.676 | 0.981 |
| 3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8wo | False | 9.849 | 0.984 |
| 4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | fp6 | False | 10.034 | 0.995 |
| 5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | fp6 | False | 10.301 | 0.997 |
| 6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8dq | False | 9.849 | 1.015 |
| 7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8dq | False | 9.675 | 1.054 |
| 8 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int4wo | False | 9.798 | 10.201 |
| 9 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int4wo | False | 9.609 | 10.213 |

This is with PyTorch nightly and torchao installed from the latest commit. The changes in https://github.com/sayakpaul/diffusers-torchao/commit/fa9154eeaf9e805f4bff8cfff209215ee425a163 played a part too, inspired by https://github.com/sayakpaul/diffusers-torchao/commit/fa9154eeaf9e805f4bff8cfff209215ee425a163.

LMK WYT.

jerryzh168 commented 3 months ago

Thanks @sayakpaul, this result is similar to what I get, actually. int4wo seems to have blown up, and all the other quant methods run in similar time compared to fp16, but it's nice to see that autoquant can still get some speedup. I think we can investigate why int4 weight-only quant does not work. A good starting point may be to output the generated code and check whether something is wrong, for example: https://github.com/pytorch/ao/blob/d582f9aba0c1ba5275384d9e81999854b4d951e0/tutorials/quantize_vit/run.sh#L13 (basically, set TORCH_LOGS='output_code' when running the benchmark). We can do the same for other types of quantization as well.
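A small sketch of what that run could look like from Python, assuming the benchmark is launched as `benchmark_pixart.py --quantization=... --compile` as in the command shown earlier in this thread. The helper names, the `int4wo` flag value, and the log file path are illustrative assumptions; `TORCH_LOGS=output_code` is the setting named above.

```python
import os
import subprocess
import sys


def output_code_run(quantization, script="benchmark_pixart.py"):
    """Build (cmd, env) for one compiled benchmark run with generated-code logging."""
    env = dict(os.environ)
    env["TORCH_LOGS"] = "output_code"  # ask torch.compile to log the code it generates
    cmd = [sys.executable, script, f"--quantization={quantization}", "--compile"]
    return cmd, env


def dump_generated_code(quantization, log_path):
    """Run the benchmark, capturing stdout and stderr in log_path for later inspection."""
    cmd, env = output_code_run(quantization)
    with open(log_path, "w") as log_file:
        subprocess.run(cmd, env=env, stdout=log_file, stderr=subprocess.STDOUT, check=True)


# Only build and show the command here; call dump_generated_code(...) to actually run it.
cmd, env = output_code_run("int4wo")
print("TORCH_LOGS=" + env["TORCH_LOGS"], " ".join(cmd))
```

Writing one log per quantization type (for example `dump_generated_code("int8dq", "int8dq_output_code.log")`) makes it easier to diff the generated kernels between a fast and a slow configuration.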

another question I have is: is this model more compute bound or memory bound? I think weight only quant or dynamic quant would help more with memory bound use cases

sayakpaul commented 3 months ago

Note that I am on bf16 and not on fp16. I can get you the code for int4, first. But LMK if you need any other.

> is this model more compute bound or memory bound? I think weight only quant or dynamic quant would help more with memory bound use cases

It is more compute bound. I can use this code to get you a trace as well: https://github.com/huggingface/diffusion-fast/blob/main/run_profile.py
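A back-of-the-envelope way to sanity-check that claim is the arithmetic intensity (FLOPs per byte moved) of a single linear layer, compared against a hardware ridge point. Everything numeric in the sketch below is an assumption for illustration: a hidden size of 1152, 4096 image tokens doubled for classifier-free guidance, and placeholder peak compute and bandwidth figures that are not a statement about the GPU used in this thread.

```python
def arithmetic_intensity(tokens, in_features, out_features, act_bytes=2, weight_bytes=2):
    """FLOPs per byte for y = x @ W.T, counting one read of x and W and one write of y."""
    flops = 2 * tokens * in_features * out_features
    bytes_moved = (
        act_bytes * tokens * (in_features + out_features)
        + weight_bytes * in_features * out_features
    )
    return flops / bytes_moved


def classify(tokens, in_features, out_features, peak_flops=300e12, bandwidth=2e12, **kwargs):
    """Compare the intensity with the ridge point peak_flops / bandwidth (placeholders)."""
    ridge = peak_flops / bandwidth
    intensity = arithmetic_intensity(tokens, in_features, out_features, **kwargs)
    return "compute-bound" if intensity > ridge else "memory-bound"


HIDDEN = 1152      # assumed transformer width
TOKENS = 2 * 4096  # assumed tokens per forward pass (CFG doubles the batch)

for name, weight_bytes in [("bf16 weights", 2), ("int8 weights", 1)]:
    ai = arithmetic_intensity(TOKENS, HIDDEN, HIDDEN, weight_bytes=weight_bytes)
    print(f"{name}: {ai:.0f} FLOPs/byte -> {classify(TOKENS, HIDDEN, HIDDEN, weight_bytes=weight_bytes)}")

# A single-token matmul, by contrast, is dominated by reading the weights.
print("1 token:", classify(1, HIDDEN, HIDDEN))
```

Under these assumptions, shrinking the weights barely moves the intensity when thousands of tokens share each weight read, which fits the expectation that weight-only quantization helps mainly in memory-bound settings.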

LMK if you have any other suggestion.

jerryzh168 commented 3 months ago

Thanks @sayakpaul, a trace might be helpful as well. For compute-bound workloads I feel static quant would make more sense, but we can look at the trace to see where the time is spent (a trace for bfloat16, int8dq and int4wo would be helpful, I think).

jerryzh168 commented 3 months ago

@sayakpaul maybe you can try out float8 inference static quant as well: https://github.com/pytorch/ao/blob/d582f9aba0c1ba5275384d9e81999854b4d951e0/torchao/float8/inference.py#L34

sayakpaul commented 3 months ago

Dynamic FP8 (compile set to True):

batch size 1:
"memory": "9.672", "time": "0.914"
batch size 4:
"memory": "9.673", "time": "3.259"
batch size 8:
"memory": "9.674", "time": "6.398"

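Seems like dynamic FP8 is faster than vanilla compiled BF16 at every batch size (0.914 vs. 1.034 at batch size 1, 3.259 vs. 3.893 at 4, and 6.398 vs. 7.68 at 8).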
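A quick way to put numbers on that comparison, using the compiled BF16 rows from the first table in this issue (compile=True, quantization=None) as the baseline. Only the helper name `speedup` is made up here; the times are copied from the thread, in whatever unit the benchmark reports.

```python
# Times copied from this thread.
bf16_time = {1: 1.034, 4: 3.893, 8: 7.68}          # compile=True, no quantization
fp8_dynamic_time = {1: 0.914, 4: 3.259, 8: 6.398}  # compile=True, dynamic FP8


def speedup(baseline, candidate):
    """Ratio baseline / candidate per batch size; above 1.0 means candidate is faster."""
    return {bs: baseline[bs] / candidate[bs] for bs in baseline}


speedups = speedup(bf16_time, fp8_dynamic_time)
for bs, ratio in sorted(speedups.items()):
    print(f"batch_size={bs}: {ratio:.2f}x")
```

The ratio grows slightly with batch size, from roughly 1.13x at batch size 1 to about 1.20x at batch size 8.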

sayakpaul commented 3 months ago

@jerryzh168 with static FP8, I am hitting the following:

Traceback (most recent call last):
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1437, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/__init__.py", line 2231, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1230, in compile_fx
    return compile_fx(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1488, in compile_fx
    return aot_autograd(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1019, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1008, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 429, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 730, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 178, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1327, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1394, in _fw_compiler_base
    return inner_compile(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 457, in compile_fx_inner
    return _compile_fx_inner(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/debug.py", line 313, in inner
    return fn(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 446, in wrapper
    return fn(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 644, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 1293, in load
    compiled_graph = compile_fx_fn(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 554, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 855, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1844, in compile_to_fn
    return self.compile_to_module().call
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1770, in compile_to_module
    return self._compile_to_module()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1776, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1717, in codegen
    self.scheduler = Scheduler(self.operations)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 1553, in __init__
    self._init(nodes)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 1623, in _init
    self.nodes = self.fuse_nodes(self.nodes)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2017, in fuse_nodes
    nodes = self.fuse_nodes_once(nodes)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2291, in fuse_nodes_once
    if not self.speedup_by_fusion(node1, node2):
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2214, in speedup_by_fusion
    ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/scheduler.py", line 2052, in benchmark_fused_nodes
    return backend.benchmark_fused_nodes(nodes)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 92, in benchmark_fused_nodes
    return self._triton_scheduling.benchmark_fused_nodes(nodes)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/triton.py", line 3028, in benchmark_fused_nodes
    src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/simd.py", line 1805, in generate_kernel_code_from_nodes
    src_code = self.codegen_template(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/codegen/simd.py", line 1489, in codegen_template
    partial_code.finalize_hook("<DEF_KERNEL>")
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 97, in finalize_hook
    self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 306, in hook
    code.splice(self.jit_lines())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 235, in jit_lines
    num_gb = self.estimate_kernel_num_bytes() / 1e9
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_inductor/select_algorithm.py", line 206, in estimate_kernel_num_bytes
    numel = functools.reduce(operator.mul, size)
TypeError: reduce() of empty sequence with no initial value

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 175, in <module>
    info = run_benchmark(pipeline, args)
  File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 113, in run_benchmark
    run_inference(pipeline, batch_size=args.batch_size)
  File "/fsx/sayak/diffusers-torchao/inference/benchmark_pixart.py", line 94, in run_inference
    _ = pipe(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 834, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 464, in _fn
    return fn(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1224, in __call__
    return self._torchdynamo_orig_callable(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 896, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 662, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 697, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 631, in transform
    tracer.run()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2722, in run
    super().run()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 957, in run
    while self.step():
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 869, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2913, in RETURN_VALUE
    self._return(inst)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2898, in _return
    self.output.compile_subgraph(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1133, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1360, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1407, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1456, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: reduce() of empty sequence with no initial value

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

(I am on torch nightly 2.5.0.dev20240806+cu121).
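The innermost frame fails in `functools.reduce(operator.mul, size)`, which raises whenever `size` is empty. One plausible reading (an assumption, not confirmed in this thread) is that the size belongs to a 0-dim tensor such as a scalar scale. The pure-Python sketch below reproduces the error and shows how an initial value of 1 avoids it; the helper names are illustrative and are not inductor code.

```python
import functools
import operator


def numel_unsafe(size):
    # Mirrors the failing call in the traceback: no initial value is given.
    return functools.reduce(operator.mul, size)


def numel_safe(size):
    # An initial value of 1 makes the empty product well defined:
    # a 0-dim (scalar) tensor holds exactly one element.
    return functools.reduce(operator.mul, size, 1)


scalar_shape = ()  # shape of a 0-dim tensor, e.g. torch.tensor(0.5).shape

try:
    numel_unsafe(scalar_shape)
except TypeError as exc:
    error_message = str(exc)

print(error_message)
print(numel_safe(scalar_shape), numel_safe((2, 3)))
```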

I had to make the following changes to make static quant work:

from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
pipeline.transformer = quantize_to_float8(pipeline.transformer, QuantConfig(ActivationCasting.STATIC, static_quantization_scale=torch.tensor(0.5).to("cuda")))

sayakpaul commented 3 months ago

Weird that with autotuned autoquant, we get worse results:

{"batch_size": 1, "fuse": false, "compile": true, "quantization": "autoquant", "sparsify": false, "memory": "9.676", "time": "1.005"}

Here's the data.pkl: https://huggingface.co/datasets/sayakpaul/torchao-diffusers/blob/main/traces/data.pkl

Experiment was launched with:

export TORCHAO_AUTOTUNER_ENABLE=1
python benchmark_pixart.py --compile --quantization=autoquant

sayakpaul commented 3 months ago

With the latest nightlies in both torch and torchao:

| ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | autoquant | False | 10.203 | 0.958 |

sayakpaul commented 3 months ago

New results (cc: @jerryzh168):

| | ckpt_id | batch_size | fuse | compile | quantization | sparsify | memory | time |
|---|---|---|---|---|---|---|---|---|
| 0 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | None | False | 10.214 | 0.865 |
| 1 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8wo | False | 9.847 | 0.897 |
| 2 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | fp6 | False | 10.299 | 0.926 |
| 3 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8wo | False | 9.672 | 0.93 |
| 4 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | fp6 | False | 10.028 | 0.933 |
| 5 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int8dq | False | 9.846 | 0.939 |
| 6 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | False | True | int8dq | False | 9.673 | 0.98 |
| 7 | PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 1 | True | True | int4wo | False | 9.795 | 10.142 |

jerryzh168 commented 3 months ago

> {"batch_size": 1, "fuse": false, "compile": true, "quantization": "autoquant", "sparsify": false, "memory": "9.676", "time": "1.005"}

Thanks for the update @sayakpaul. Maybe we can try larger models. Did you get a speedup from quanto with the current model, btw?

sayakpaul commented 3 months ago

Could you help run SD3 as discussed just to gauge the ballpark?

sayakpaul commented 2 months ago

Closing this issue as it has served its purpose.