Open dru10 opened 1 month ago
After altering the code from torch export (which probably wasn't necessary), I got the error shown in the traceback below.
For reference, this is the modification I made inside py/torch_tensorrt/dynamo/_tracer.py#L81
exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))
And this is the traceback
Traceback (most recent call last):
File "/workspace/torch-tensorrt/src/dummy.py", line 21, in <module>
trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 248, in compile
exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
return _export(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper
raise e
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper
ep = fn(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper
return fn(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export
gm_torch_level = _export_to_torch_ir(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner
result_traced = opt_f(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1272, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
tracer.run()
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1644, in CONTAINS_OP
self.push(right.call_method(self, "__contains__", [left], {}))
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py", line 182, in call_method
result = search in self.value
torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable
from user code:
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1162, in forward
aug_emb = self.get_aug_embed(
File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 973, in get_aug_embed
if "text_embeds" not in added_cond_kwargs:
I have the same question. Have you solved the problem?
❓ Question
How do you save a unet model compiled Torch-TensorRT from Stable Diffusion XL?
What you have already tried
I've tried following the compilation instructions from the tutorial (link). It wasn't very useful for my use case because I would like to save the compilation on disk and load it down the line when inference is needed.
So I've tried following the instructions which let you save your compilation using the dynamo backend (link). This script represents a summary of what I'm doing:
But this yields the following error:
TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
So, I've tried to provide these arguments as well, found after some playing around with the code from diffusers:
And I get the same error. Probably the kwargs don't get passed down to the functions being called. After altering the code from torch export (which probably wasn't necessary), I got an error of the type:
torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable
Any ideas how to properly compile a unet model from stable diffusion XL? Many thanks in advance.
Environment
How you installed PyTorch (conda, pip, libtorch, source): pip install torch --index-url https://download.pytorch.org/whl/cu121
Additional context