Problem
If I trace my model with the weird shape first and then the nice shape, everything is fine, whereas if I trace it with the nice shape first, it segfaults (backtrace here).
Note: this problem only reproduces for me on inpainting models (i.e. where conv_in has 8 channels). I don't know why.
Possibly related: https://github.com/chengzeyi/stable-fast/issues/153#issuecomment-2232237669

Proximal cause
I can get a better error message if I set enable_jit_freeze = False in my CompilationConfig:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 76 but got size 75 for tensor number 1 in the list.
sfast/jit/overrides.py(21): __torch_function__
diffusers/models/unet_2d_blocks.py(2521): forward
diffusers/models/unet_2d_condition.py(1281): forward
sfast/jit/trace_helper.py(89): forward
The traceback points to this line of code:
https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_blocks.py#L2521
CrossAttnUpBlock2D#forward
The decoder's upsampled hidden_states are not the same size as the encoder's residual hidden states.
Ordinarily (stable-fast disabled), we expect both hidden states to have a height of 75.
But the JIT has caused the decoder's upsampled hidden states to have a height of 76 instead of 75.
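For illustration only (the channel and width values below are made up, not taken from the real model), the same error falls out of a bare torch.cat, since the up-block concatenates the upsampled decoder features with the encoder residual along the channel dimension and therefore needs every other dimension to agree:

import torch

decoder_hidden = torch.randn(1, 320, 76, 100)    # hypothetical upsampled decoder features
encoder_residual = torch.randn(1, 320, 75, 100)  # hypothetical encoder skip connection
torch.cat([decoder_hidden, encoder_residual], dim=1)
# RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 76 but got size 75 for tensor number 1 in the list.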
Root cause
It turns out stable-diffusion has two upsample algorithms, and flow control determines which one is used.
https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/upsampling.py#L167-L171
Upsample2D#forward
Sometimes you do not upsample by 2x.
Sometimes you upsample to a target resolution instead!
"Upsample to a target resolution" happens when the latent height or width is indivisible by 4:
https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_condition.py#L1106-L1113
UNet2DConditionModel#forward

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break
That's why 64x64 behaves differently to 150x157.
It's also why the 150x157 solution generalizes, but the 64x64 solution does not.
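To make the arithmetic concrete (a throwaway sketch; it uses the divisible-by-4 requirement described above, though the exact factor depends on how many upsamplers the UNet has):

default_overall_up_factor = 4
for h, w in ((64, 64), (150, 157)):
    needs_explicit_size = any(dim % default_overall_up_factor != 0 for dim in (h, w))
    print(h, w, needs_explicit_size)
# 64 64 False    -> the scale_factor=2 codepath gets traced
# 150 157 True   -> the explicit-output-size codepath gets traced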
Evidence
We can print ScriptModule#code to see the difference in the traced model:
https://github.com/chengzeyi/stable-fast/blob/fffe290680ec2ddc01f511e8e7fc62357ed901d8/src/sfast/jit/utils.py#L36

s = ''
for name, mod in script_module.named_modules():
    if hasattr(mod, 'code'):
        s += f'\n{name} {mod.__class__.__name__} code:\n{mod.code}\n'
with open('code.py', 'w') as f:
    f.write(s)
Or if you prefer the IR representation:
s = ''
for name, mod in script_module.named_modules():
    try:
        if hasattr(mod, 'graph'):
            s += f'\n{name} {mod.__class__.__name__} graph:\n{mod.graph}\n'
        else:
            s += f'\n{name} {mod.__class__.__name__} graph:\n(no graph)\n'
    except RuntimeError:
        s += f'\n{name} {mod.__class__.__name__} graph:\nRuntimeError\n'
with open('graph.txt', 'w') as f:
    f.write(s)
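One way to locate the differences (my sketch, not something from the issue; the filenames are hypothetical dumps from a weird-shape trace and a nice-shape trace) is to diff the two code dumps:

import difflib

with open('code_weird.py') as f:
    weird = f.readlines()
with open('code_nice.py') as f:
    nice = f.readlines()
print(''.join(difflib.unified_diff(nice, weird, fromfile='code_nice.py', tofile='code_weird.py')))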
There are only a couple of differences in the codegen!
UNet2DConditionModel#forward passes more arguments when the tensor shape is weird (it's passing upsample_size!)
[screenshots: left = weird tensor shape, right = simple tensor shape]
The up-blocks only pass the upsample_size argument through if the tensor shape is weird.
[screenshots: left = weird tensor shape, right = simple tensor shape]
Upsample2D#forward only invokes upsample_nearest2d with a target size if the tensor shape is weird.
[screenshots: left = weird tensor shape, right = simple tensor shape]
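This is just the usual torch.jit.trace behaviour with data-dependent control flow: only the branch taken on the example input is recorded. A self-contained toy (my sketch, not diffusers code) that freezes the same kind of shape decision:

import torch
import torch.nn.functional as F

def toy(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the diffusers branch: "weird" widths get an explicit target size
    # (2*dim - 1 here, mimicking the 75-vs-76 mismatch), "nice" widths get scale_factor=2.
    if x.shape[-1] % 4 != 0:
        return F.interpolate(x, size=(x.shape[-2] * 2 - 1, x.shape[-1] * 2 - 1), mode='nearest')
    return F.interpolate(x, scale_factor=2.0, mode='nearest')

# The `if` runs in Python at trace time, so only the branch taken for the
# 8x8 example input ends up in the graph.
traced = torch.jit.trace(toy, torch.randn(1, 1, 8, 8))
print(toy(torch.randn(1, 1, 9, 9)).shape)     # eager: torch.Size([1, 1, 17, 17])
print(traced(torch.randn(1, 1, 9, 9)).shape)  # traced: not 17x17 -- it replays the scale_factor branch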
Proposed solution
Maybe it's fine to just require users to trace the model first with a weird tensor shape? So that you always go down the "give upsample_nearest2d a target size" codepath.
Or modify UNet2DConditionModel#forward to always set forward_upsample_size = True. That would achieve the same outcome more cheaply.
I don't know whether torch.upsample_nearest2d() returns the same result when you pass a scale_factor as when you pass an explicit output size. I'm optimistic that it probably would, at least for scale_factor=2; a quick check is sketched below.
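Something like this throwaway check could confirm it (going through F.interpolate, which dispatches to upsample_nearest2d for 4D inputs; the 75x79 shape is arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 75, 79)
by_factor = F.interpolate(x, scale_factor=2.0, mode='nearest')
by_size = F.interpolate(x, size=(150, 158), mode='nearest')
print(by_factor.shape == by_size.shape, torch.equal(by_factor, by_size))  # expected: True True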
Or maybe this section of UNet2DConditionModel#forward needs to be JITed in script-mode, to enable control flow:
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break
Maybe something like this? I haven't tried running this code.
from torch.jit import script_if_tracing
from torch import Tensor
# …

@script_if_tracing
def should_fwd_upsample_size(
    sample: Tensor,
    default_overall_up_factor: int
) -> bool:
    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Forward upsample size to force interpolation output size.
            return True
    return False

forward_upsample_size: bool = should_fwd_upsample_size(sample, default_overall_up_factor)
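For what it's worth, torch.jit.script_if_tracing compiles the decorated helper only when it is first called during tracing, so eager-mode behaviour shouldn't change. Whether this alone fixes the bug is less clear to me: the downstream code that consumes forward_upsample_size (deciding whether to pass upsample_size to the up-blocks) is still Python control flow, so it may also need to survive tracing for the graph to stop specializing on the example shape.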