huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

traced UNet from documentation fails with RuntimeError: expected scalar type Half but found Float #720

Closed Thomas-MMJ closed 1 year ago

Thomas-MMJ commented 1 year ago

Describe the bug

When running the traced UNet example from

https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx

the following error is raised: RuntimeError: expected scalar type Half but found Float

It occurs in the warmup section:

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

Reproduction

import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# torch disable grad
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50

# load inputs
def generate_inputs():
    sample = torch.randn(2, 4, 64, 64).half().cuda()
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
    return sample, timestep, encoder_hidden_states

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # scheduler=scheduler,
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

Logs

RuntimeError                              Traceback (most recent call last)
Cell In [1], line 37
     35     with torch.inference_mode():
     36         inputs = generate_inputs()
---> 37         orig_output = unet(*inputs)
     39 # trace
     40 print("tracing..")

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\diffusers\models\unet_2d_condition.py:225, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, return_dict)
    222 timesteps = timesteps.expand(sample.shape[0])
    224 t_emb = self.time_proj(timesteps)
--> 225 emb = self.time_embedding(t_emb)
    227 # 2. pre-process
    228 sample = self.conv_in(sample)

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\diffusers\models\embeddings.py:73, in TimestepEmbedding.forward(self, sample)
     72 def forward(self, sample):
---> 73     sample = self.linear_1(sample)
     75     if self.act is not None:
     76         sample = self.act(sample)

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File c:\users\username\mambaforge\envs\ldm\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: expected scalar type Half but found Float
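
The traceback ends in F.linear, which suggests that the tensor reaching the time-embedding layer has a different dtype than the layer's weights (presumably float32 input against float16 weights, since the model was loaded with torch_dtype=torch.float16). The sketch below is only an illustration of that kind of mismatch and of one way to avoid it by casting the input; it is not the fix applied in diffusers. The layer, the tensor, and the helper call_with_matching_dtype are hypothetical, and float64 stands in for float16 so the snippet also runs on CPU-only machines.

```python
import torch

# Hypothetical stand-in for the half-precision time-embedding layer.
# float64 replaces float16 here so the sketch runs on CPU as well.
layer = torch.nn.Linear(4, 4).double()

# float32 input, standing in for the projected timestep embedding.
x = torch.randn(2, 4)


def call_with_matching_dtype(module, tensor):
    """Cast the input to the dtype of the module's first parameter, then call the module."""
    target_dtype = next(module.parameters()).dtype
    return module(tensor.to(target_dtype))


# Calling the layer directly fails because input and weight dtypes differ.
try:
    layer(x)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("dtype mismatch:", err)

# Casting the input first lets the same call succeed.
out = call_with_matching_dtype(layer, x)
print(out.dtype)
```

The exact wording of the error message depends on the PyTorch version and device, so the sketch only checks that a RuntimeError is raised.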

System Info

Thomas-MMJ commented 1 year ago

Using the development version it works, so this only affects the released version.

Marcophono2 commented 1 year ago

Hi @Thomas-MMJ ! Just curious: What do you mean by "development version"?

Best regards Marc

Thomas-MMJ commented 1 year ago

I meant that when I built from source at head (the main development branch) it worked, but it did not work with the version installed via pip install diffusers.

That said, it no longer works because they are now using a frozendict, which prevents substitution of the traced UNet.
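
Since the behaviour differs between a source build and the pip release, it can help to confirm which build is actually imported before reporting results. The following is a minimal sketch using only the standard library; the helper describe_install is hypothetical, and treating a "dev" marker in the version string as a sign of a source build is an assumption that may not hold for every install.

```python
from importlib import metadata


def describe_install(package: str) -> dict:
    """Report whether a package is installed and whether its version looks like a dev build."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        return {
            "package": package,
            "installed": False,
            "version": None,
            "looks_like_dev": False,
        }
    return {
        "package": package,
        "installed": True,
        "version": version,
        # Assumption: development builds carry a "dev" segment in the version string.
        "looks_like_dev": "dev" in version,
    }


print(describe_install("diffusers"))
```

Including the reported version string in a bug report makes it easier to tell whether a problem comes from the released package or from the main branch.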