sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

Issue with torchao version maybe? #19

Closed: gradjitta closed this issue 2 months ago

gradjitta commented 2 months ago

Hi @sayakpaul! Thanks for all your work related to Flux and diffusers. I came across your post on X about faster Flux.

I have been following the snippet you shared there, but I ran into an error during compilation with torchao:

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import autoquant

# PATH_TO_DEV is a placeholder for a local FLUX.1-dev checkpoint.
pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")
pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
pipeline.transformer.to(memory_format=torch.channels_last)

pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

# warmup
for _ in range(3):
    _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)
```

The warmup loop triggers the error below:

Compile error:

```python
InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[7], line 3
      1 # warmup
      2 for _ in range(3):
----> 3     _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)

File ~/flux-stuff/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py:719, in FluxPipeline.__call__(...)
--> 719 noise_pred = self.transformer(
    720     hidden_states=latents,
        ...
    729 )[0]

[... intermediate frames through torch/nn/modules/module.py, torch/_dynamo
(eval_frame, convert_frame, symbolic_convert, variables/*) and
torch/_subclasses (fake_tensor, meta_utils) elided for brevity ...]

File ~/flux-stuff/ao/torchao/quantization/autoquant.py:179, in AutoQuantizableLinearWeight.__tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride)
    177 weight = tensor_data_dict["weight"]
    178 qtensor_class_list, mode, dtype, shape = tensor_attributes[0]
--> 179 return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

File ~/flux-stuff/ao/torchao/quantization/autoquant.py:67, in AutoQuantizableLinearWeight.__new__(cls, weight, qtensor_class_list, mode, *args, **kwargs)
     65 kwargs["requires_grad"] = False
     66 shape = kwargs.pop("shape", weight.shape)
---> 67 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

InternalTorchDynamoError: TypeError: _make_wrapper_subclass(): argument 'dtype' must be torch.dtype, not torch._C._TensorMeta

from user code:
  File "/home/mlops/flux-stuff/ao/torchao/quantization/autoquant.py", line 651, in autoquant_prehook
    real_model.forward(*args, **kwargs)
  File "/home/mlops/flux-stuff/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 442, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/mlops/miniconda3/envs/flux-fast/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

And the following is my environment:

My env:

```shell
Collecting environment information...
PyTorch version: 2.5.0.dev20240906
Is debug build: False
CUDA used to build PyTorch: 12.4
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Libc version: glibc-2.35
Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 12.3.107
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 550.90.07
cuDNN version: 8.9.7
CPU: Intel(R) Xeon(R) Platinum 8462Y+, 30 cores (x86_64, KVM guest)

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.5.0.dev20240906
[pip3] torchao==0.6.0+gitc6abf2bd
[pip3] torchaudio==2.5.0.dev20240907
[pip3] torchvision==0.20.0.dev20240907
[pip3] triton==3.0.0
[conda] pytorch            2.5.0.dev20240906   py3.11_cuda12.4_cudnn9.1.0_0   pytorch-nightly
[conda] pytorch-cuda       12.4                hc786d27_7                     pytorch-nightly
[conda] torchao            0.6.0+gitc6abf2bd   dev_0
[conda] torchaudio         2.5.0.dev20240907   py311_cu124                    pytorch-nightly
[conda] torchtriton        3.0.0+757b6a61e7    py311                          pytorch-nightly
[conda] torchvision        0.20.0.dev20240907  py311_cu124                    pytorch-nightly

[... CPU flags, security-mitigation details, and remaining conda packages elided ...]
```

When I don't use autoquant, the compilation goes through.

Maybe I am missing something trivial; otherwise, could you point me to the correct torchao commit? Cheers!
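
In the meantime I can re-run with the verbose Dynamo logging the error message itself suggests (the flags come straight from the traceback above; `repro.py` is just a stand-in name for the snippet saved to a file):

```shell
TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python repro.py
```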

sayakpaul commented 2 months ago

What is your GPU? Does installing torchao from source help?

gradjitta commented 2 months ago

I shared my env above; it's an H100 80GB HBM3, and I did build torchao from source.

And it looks like I don't see any error when I do the following instead:

```python
quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
```

Edit: this seems to work. Compilation takes around 14 min 20 s, and after that throughput goes from 4.16 it/s to 9.98 it/s.
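
For reference, here is roughly the full FP8 flow as a minimal sketch (PATH_TO_DEV is again a placeholder for a local FLUX.1-dev checkpoint):

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")

# quantize_ swaps the Linear weights in place, so it runs before torch.compile.
quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

# The long (~14 min) compile happens during the first warmup calls.
for _ in range(3):
    _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)
```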

gradjitta commented 2 months ago
[Image: initial output, without compile]
[Image: output after compile and FP8]
sayakpaul commented 2 months ago

Nice.

For autoquant, could you maybe try using the benchmark_image.py script?

gradjitta commented 2 months ago

Oh, with the benchmark script it goes through. I am using this:

```shell
python3 benchmark_image.py --compile --quantization autoquant --batch_size 1
```

Maybe the difference is the class, FluxPipeline vs. DiffusionPipeline? (EDIT: I can rerun the benchmark script with FluxPipeline and autotune.)

| ckpt_id | batch_size | fuse | compile | compile_vae | quantization | sparsify | model_memory | inference_memory | time |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| black-forest-labs/FLUX.1-dev | 1 | False | True | False | autoquant | False | 31.438 | 32.461 | 3.407 |
gradjitta commented 2 months ago

This order results in the error:

```python
pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")
pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
pipeline.transformer.to(memory_format=torch.channels_last)

pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
```

whereas this one doesn't:


```python
pipeline = FluxPipeline.from_pretrained(PATH_TO_DEV, torch_dtype=torch.bfloat16).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
```
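
For what it's worth, the working order seems to match the pattern in the torchao README, where autoquant wraps the already-compiled module, so the second snippet can be collapsed into a single expression:

```python
# Compile first, then let autoquant wrap the compiled transformer
# (same ordering as the working snippet above).
pipeline.transformer = autoquant(
    torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True),
    error_on_unseen=False,
)
```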
sayakpaul commented 2 months ago

Interesting. But for DiffusionPipeline this doesn’t happen?

gradjitta commented 2 months ago

It's the same for both DiffusionPipeline and FluxPipeline. At this point, I think it's the order in which torch.compile is called.

sayakpaul commented 2 months ago

Oh yes, order matters a lot, and it's different for autoquant and quantize_(). Judging by the traceback, autoquant installs a pre-forward hook (autoquant_prehook) and wraps the weights in a tensor subclass, which Dynamo then fails to trace when torch.compile is applied afterwards. We should follow the order from benchmark_image.py.

Closing this issue then.

sayakpaul commented 2 months ago

I have also made it clear what the order of autoquant and torch.compile() should be: https://github.com/sayakpaul/diffusers-torchao/commit/ec305cf7c37b7eb52a922cbba9336baf881588ab