
segfault when latent height or width is not divisible by 4 #159

Open Birch-san opened 3 months ago

Birch-san commented 3 months ago

Problem

If I trace my model like so (weird shape then nice shape), everything is fine:

model(randn(1, 4, 150, 157))
model(randn(1, 4, 64, 64))

whereas if I trace my model with the nice shape first, it segfaults (backtrace here):

model(randn(1, 4, 64, 64))
model(randn(1, 4, 150, 157))

Note: this problem only reproduces for me on inpainting models (i.e. where conv_in has 8 channels). I don't know why.

Possibly related:
https://github.com/chengzeyi/stable-fast/issues/153#issuecomment-2232237669

Proximate cause

I can get a better error message if I set enable_jit_freeze = False in my CompilationConfig:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 76 but got size 75 for tensor number 1 in the list.
sfast/jit/overrides.py(21): __torch_function__
diffusers/models/unet_2d_blocks.py(2521): forward
diffusers/models/unet_2d_condition.py(1281): forward
sfast/jit/trace_helper.py(89): forward
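For reference, the config tweak that surfaces this error looks roughly like the sketch below. The compile entry point and CompilationConfig.Default() are taken from stable-fast's README and may differ between versions; enable_jit_freeze is the switch discussed above.

from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

config = CompilationConfig.Default()
config.enable_jit_freeze = False  # with freezing disabled, the failure surfaces as the RuntimeError above rather than a segfault
compiled_pipe = compile(pipe, config)  # `pipe` is the Diffusers pipeline being compiled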

The traceback above leads to this line of code:
https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_blocks.py#L2521

CrossAttnUpBlock2D#forward

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

The decoder's upsampled hidden_states are not the same size as the encoder's residual hidden states.

Ordinarily (stable-fast disabled), we expect both hidden states to have a height of 75:

    hidden_states = torch.Size([2, 1280, 75, 79])
res_hidden_states = torch.Size([2, 640, 75, 79])

    hidden_states = torch.Size([2, 640, 75, 79])
res_hidden_states = torch.Size([2, 640, 75, 79])

    hidden_states = torch.Size([2, 640, 75, 79])
res_hidden_states = torch.Size([2, 320, 75, 79])

But the JIT has caused the decoder's upsampled hidden states to have a height of 76 instead of 75.
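As a standalone sanity check (shapes copied from the failing case; the snippet itself isn't from the traced model), torch.cat raises exactly that RuntimeError once the heights disagree:

import torch

# Channel counts are illustrative; what matters is the 76-vs-75 height mismatch.
hidden_states = torch.randn(2, 1280, 76, 79)     # decoder output: 38 * 2 = 76 from the scale_factor path
res_hidden_states = torch.randn(2, 640, 75, 79)  # encoder residual: height 75

# RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 76 but got size 75 for tensor number 1 in the list.
torch.cat([hidden_states, res_hidden_states], dim=1)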

Root cause

It turns out Stable Diffusion's UNet has two upsampling code paths, and Python control flow decides which one runs.

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/upsampling.py#L167-L171

Upsample2D#forward

if self.interpolate:
    if output_size is None:
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
    else:
        hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

Sometimes you do not upsample by 2x.
Sometimes you upsample to a target resolution instead!
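Here's a minimal sketch of why that matters for the 150x157 latent (the tensor below stands in for a decoder activation two downsamples below the latent; the channel count is illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 640, 38, 40)  # stand-in decoder activation for a 150x157 latent

# scale_factor path: blindly doubles the spatial dims ...
doubled = F.interpolate(x, scale_factor=2.0, mode="nearest")
print(doubled.shape)  # torch.Size([2, 640, 76, 80])

# ... but the encoder skip connection at this level is 75x79, so only the
# explicit-size path produces something torch.cat can concatenate with it.
matched = F.interpolate(x, size=(75, 79), mode="nearest")
print(matched.shape)  # torch.Size([2, 640, 75, 79])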

"Upsample to a target resolution" happens when the latent height or width is indivisible by 4:

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_condition.py#L1106-L1113

UNet2DConditionModel#forward

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break

That's why 64x64 behaves differently from 150x157.
It's also why a model traced at 150x157 generalizes to both shapes, while a model traced at 64x64 does not.
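To make the check concrete, here it is rewritten standalone (default_overall_up_factor is 2 ** num_upsamplers in diffusers and depends on the model; for these two shapes the outcome is the same whether the factor is 4 or 8):

def needs_forwarded_upsample_size(height: int, width: int, up_factor: int) -> bool:
    # Forward an explicit upsample size whenever either spatial dim
    # is not a multiple of the overall up factor.
    return height % up_factor != 0 or width % up_factor != 0

print(needs_forwarded_upsample_size(64, 64, up_factor=8))    # False -> scale_factor=2.0 path
print(needs_forwarded_upsample_size(150, 157, up_factor=8))  # True  -> explicit-size path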

Evidence

We can print ScriptModule#code to see the differences between the two traced models:
https://github.com/chengzeyi/stable-fast/blob/fffe290680ec2ddc01f511e8e7fc62357ed901d8/src/sfast/jit/utils.py#L36

# Dump the generated TorchScript code for every submodule to a file.
s = ''
for name, mod in script_module.named_modules():
    if hasattr(mod, 'code'):
        s += f'\n{name} {mod.__class__.__name__} code:\n{mod.code}\n'
with open('code.py', 'w') as f:
    f.write(s)

Or if you prefer the IR representation:

# Dump the TorchScript IR graph for every submodule; some submodules have no
# graph or raise RuntimeError when it is accessed, so handle both cases.
s = ''
for name, mod in script_module.named_modules():
    try:
        if hasattr(mod, 'graph'):
            s += f'\n{name} {mod.__class__.__name__} graph:\n{mod.graph}\n'
        else:
            s += f'\n{name} {mod.__class__.__name__} graph:\n(no graph)\n'
    except RuntimeError:
        s += f'\n{name} {mod.__class__.__name__} graph:\nRuntimeError\n'
with open('graph.txt', 'w') as f:
    f.write(s)

There are only a couple of differences in the codegen!

UNet2DConditionModel#forward passes more arguments when the tensor shape is weird (it passes upsample_size!)
Left = weird tensor shape
Right = simple tensor shape

(screenshot: traced code comparison)

The up-blocks only pass the upsample_size argument through if the tensor shape is weird.
Left = weird tensor shape
Right = simple tensor shape

(screenshots: traced code comparison)

Upsample2D#forward only invokes upsample_nearest2d with a target size if the tensor shape is weird.
Left = weird tensor shape
Right = simple tensor shape

(screenshot: traced code comparison)

Proposed solution

Maybe it's fine to just require users to trace the model with a weird tensor shape first, so that you always go down the "give upsample_nearest2d a target size" codepath?

Or modify UNet2DConditionModel#forward to always set forward_upsample_size = True. That would achieve the same outcome more cheaply.

I don't know whether torch.upsample_nearest2d() returns the same result when you use scale_factor as when you pass an explicit output size. I'm optimistic that it probably does, at least for scale_factor=2.
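A quick way to check that claim (untested here; for an exact 2x nearest upsample both paths should pick the same source pixels):

import torch
import torch.nn.functional as F

x = torch.randn(2, 320, 75, 79)
via_scale = F.interpolate(x, scale_factor=2.0, mode="nearest")
via_size = F.interpolate(x, size=(150, 158), mode="nearest")
print(torch.equal(via_scale, via_size))  # expected: True if the two paths agree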

Or maybe this section of UNet2DConditionModel#forward needs to be JIT-compiled in script mode, so the shape-dependent control flow is preserved:

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_condition.py#L1106-L1113

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break

Maybe something like this? I haven't tried running this code.

from torch.jit import script_if_tracing
from torch import Tensor

# …

@script_if_tracing
def should_fwd_upsample_size(
    sample: Tensor,
    default_overall_up_factor: int
) -> bool:
    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Forward upsample size to force interpolation output size.
            return True
    return False

forward_upsample_size: bool = should_fwd_upsample_size(sample, default_overall_up_factor)
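The intent is that @script_if_tracing compiles the helper with TorchScript the first time it runs under tracing, so the shape-dependent loop survives as real control flow instead of the tracer baking in whichever branch the example input happened to take. I haven't verified whether that alone is sufficient, since the later `if forward_upsample_size:` branches in forward would still be traced as Python control flow.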