pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] How do you save a UNet model compiled with Torch-TensorRT (Stable Diffusion XL) #3018

Open dru10 opened 1 month ago

dru10 commented 1 month ago

❓ Question

How do you save a UNet model from Stable Diffusion XL that was compiled with Torch-TensorRT?

What you have already tried

I followed the compilation instructions from the tutorial (link), but they don't cover my use case: I want to save the compiled model to disk and load it later when inference is needed.

So I tried the instructions for saving a model compiled with the dynamo backend (link). This script summarizes what I'm doing:

import torch
import torch_tensorrt
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

inputs = [torch.randn((2, 4, 128, 128)).cuda()]  # After some digging, these are the input sizes needed to generate 1024x1024 images

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs)

But this yields the following error: TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'

So I tried providing these arguments as well, using values I found by experimenting with the diffusers code:

kwargs = {
    "timestep": torch.tensor(951.0).cuda(),
    "encoder_hidden_states": torch.randn(
        (2, 77, 2048), dtype=torch.float16
    ).cuda(),
}

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)

I get the same error, probably because the kwargs are not passed down to the function that actually calls the model. After altering the code from torch export (which probably wasn't necessary), I got a different error: torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable
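
To make the suspected kwargs behaviour concrete, here is a toy sketch in plain Python. It is not Torch-TensorRT's actual code: `ToyUNet` and `toy_compile` are hypothetical stand-ins, and the only assumption is that extra keyword arguments given to the compile call are treated as compilation settings rather than forwarded to `forward()`.

```python
# Toy illustration only: none of these names come from Torch-TensorRT or diffusers.


class ToyUNet:
    """Stand-in for a model whose forward() has several required arguments."""

    def forward(self, sample, timestep, encoder_hidden_states):
        return (sample, timestep, encoder_hidden_states)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def toy_compile(module, inputs, **settings):
    """Hypothetical compile(): extra keywords are read as settings, not model inputs."""
    # Only `inputs` is forwarded to the model; `settings` never reaches forward().
    return module(*inputs)


model = ToyUNet()
extra = {"timestep": 951.0, "encoder_hidden_states": "text-embedding"}

# Mirrors the failing call: the extra keywords are swallowed by toy_compile().
try:
    toy_compile(model, ["latents"], **extra)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Passing every forward() argument through `inputs` reaches the model.
result = toy_compile(
    model, ["latents", extra["timestep"], extra["encoder_hidden_states"]]
)
```

If the real compile call behaves like this, one option would be a thin wrapper module around the UNet whose `forward()` takes every tensor positionally, so that everything can go through `inputs`.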

Any ideas on how to properly compile the UNet from Stable Diffusion XL? Many thanks in advance.

dru10 commented 1 month ago

After altering the code from torch export (which probably wasn't necessary)

For reference, this is the modification I made in py/torch_tensorrt/dynamo/_tracer.py#L81:

exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))

And this is the traceback:

Traceback (most recent call last):
  File "/workspace/torch-tensorrt/src/dummy.py", line 21, in <module>
    trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 248, in compile
    exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
    exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper
    raise e
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper
    ep = fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export
    gm_torch_level = _export_to_torch_ir(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1272, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1644, in CONTAINS_OP
    self.push(right.call_method(self, "__contains__", [left], {}))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py", line 182, in call_method
    result = search in self.value
torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable

from user code:
   File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1162, in forward
    aug_emb = self.get_aug_embed(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 973, in get_aug_embed
    if "text_embeds" not in added_cond_kwargs:
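
The last user-code frame suggests that `added_cond_kwargs` was still `None` when the membership test ran. The sketch below reproduces that in plain Python. `check_added_cond_kwargs` is a hypothetical stand-in for the check in `get_aug_embed`, and the placeholder shapes are an assumption about what SDXL base expects for a batch of 2, not something confirmed in this thread.

```python
def check_added_cond_kwargs(added_cond_kwargs=None):
    """Toy version of the membership test shown in the last traceback frame."""
    if "text_embeds" not in added_cond_kwargs:
        raise ValueError("added_cond_kwargs needs a 'text_embeds' entry")
    if "time_ids" not in added_cond_kwargs:
        raise ValueError("added_cond_kwargs needs a 'time_ids' entry")
    return sorted(added_cond_kwargs)


# Case 1: the argument is left at its default, as in the failing compile call.
try:
    check_added_cond_kwargs()
    raised_type_error = False
except TypeError:
    # `"text_embeds" not in None` fails because None does not support `in`.
    raised_type_error = True

# Case 2: both entries are supplied. The tuples are placeholder shapes
# (assumed, for batch size 2) standing in for real tensors.
placeholder_kwargs = {
    "text_embeds": (2, 1280),  # assumed pooled text embedding shape
    "time_ids": (2, 6),        # assumed size/crop conditioning shape
}
accepted_keys = check_added_cond_kwargs(placeholder_kwargs)
```

So besides `timestep` and `encoder_hidden_states`, the example inputs for the SDXL UNet would probably also need an `added_cond_kwargs` dictionary containing `text_embeds` and `time_ids` tensors.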

pangyoki commented 3 days ago

I have the same question. Have you solved the problem?