dbolya / tomesd

Speed up Stable Diffusion with this one simple trick!
MIT License

Applying ToMe to relatively low-resolution images #32

Closed: taehong-moon closed this issue 1 year ago

taehong-moon commented 1 year ago

Thank you for providing the detailed implementation of ToMe. I applied your method to the DiT-XL/2 model trained on the ImageNet 256x256 dataset (available at https://github.com/facebookresearch/DiT) to verify its effectiveness.

However, I observed that ToMe gives limited benefit on images with a resolution of 256 compared to the results obtained for higher resolutions such as 512. Setting max_downsample to 2 did result in faster inference, but I did not observe any significant difference across ratio values.

Therefore, I am curious whether I implemented ToMe incorrectly, or whether ToMe's benefit is indeed marginal at low resolutions compared with high resolutions (512 and above).
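As a rough sanity check on this question, the sketch below estimates how much compute per transformer block token merging could save at each resolution. It is a back-of-the-envelope model, not a measurement: it assumes DiT-XL/2's hidden size of 1152, an 8x VAE downsampling factor, a patch size of 2, and that only the attention module sees merged tokens (merge_attn on, merge_mlp off). It ignores the cost of computing the merge itself, the softmax, and the norm layers. The helper names `block_cost` and `estimated_saving` are hypothetical.

```python
def block_cost(num_tokens, dim, ratio=0.0, mlp_mult=4):
    """Rough multiply-accumulate count for one transformer block when only the
    attention module runs on merged tokens (merge_attn=True, merge_mlp=False)."""
    kept = num_tokens - int(num_tokens * ratio)   # tokens left after merging
    attn_linear = 4 * kept * dim * dim            # QKV and output projections
    attn_quadratic = 2 * kept * kept * dim        # QK^T and attention @ V
    mlp = 2 * mlp_mult * num_tokens * dim * dim   # the MLP still sees every token
    return attn_linear + attn_quadratic + mlp


def estimated_saving(image_size, ratio, dim=1152, vae_factor=8, patch=2):
    """Fraction of per-block compute saved under the assumptions above."""
    side = image_size // (vae_factor * patch)     # tokens per side: 16 at 256, 32 at 512
    num_tokens = side * side
    return 1.0 - block_cost(num_tokens, dim, ratio) / block_cost(num_tokens, dim)


for res in (256, 512):
    for ratio in (0.25, 0.5):
        print(res, ratio, round(estimated_saving(res, ratio), 3))
```

Under this model the saving is bounded by the share of block compute spent in attention, and that share shrinks as the token count drops, so a smaller gain at 256 than at 512 would not by itself indicate a bug.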

To apply ToMe to the DiT model, I changed only a small portion of your code, as follows:

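from typing import Type

import torch

# Assumed context: this code lives in a modified copy of tomesd/patch.py, so remove_patch,
# hook_tome_model, compute_merge, and isinstance_str are in scope; modulate is DiT's helper.


def apply_patch(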
        model: torch.nn.Module,
        ratio: float = 0.5,
        max_downsample: int = 1,
        sx: int = 2, sy: int = 2,
        use_rand: bool = True,
        merge_attn: bool = True,
        merge_crossattn: bool = False,
        merge_mlp: bool = False):
    """
    Patches a DiT model with ToMe.
    Apply this directly to the DiT module (the one whose submodules are DiTBlock layers).

    Important Args:
     - model: A DiT model module to patch in place. It is used as-is; no ".model.diffusion_model" lookup is done.
     - ratio: The ratio of tokens to merge. I.e., 0.4 would reduce the total number of tokens by 40%.
              The maximum value for this is 1-(1/(sx*sy)). By default, the max is 0.75 (I recommend <= 0.5 though).
              Higher values result in more speed-up, but with more visual quality loss.

    Args to tinker with if you want:
     - max_downsample [1, 2, 4, or 8]: Apply ToMe to layers with at most this amount of downsampling.
                                       E.g., in the Stable Diffusion U-Net, 1 only applies to layers with no
                                       downsampling (4/15) while 8 applies to all layers (15/15). I recommend
                                       a value of 1 or 2.
     - sx, sy: The stride for computing dst sets (see paper). A higher stride means you can merge more tokens,
               but the default of (2, 2) works well in most cases. Doesn't have to divide image size.
     - use_rand: Whether or not to allow random perturbations when computing dst sets (see paper). Usually
                 you'd want to leave this on, but if you're having weird artifacts try turning this off.
     - merge_attn: Whether or not to merge tokens for attention (recommended).
     - merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
     - merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).
    """

    # Make sure the module is not currently patched
    remove_patch(model)

    diffusion_model = model

    # TODO Preserve this part!
    diffusion_model._tome_info = {
        "size": (32, 32),
        "hooks": [],
        "args": {
            "ratio": ratio,
            "max_downsample": max_downsample,
            "sx": sx, "sy": sy,
            "use_rand": use_rand,
            "generator": None,
            "merge_attn": merge_attn,
            "merge_crossattn": merge_crossattn,
            "merge_mlp": merge_mlp
        }
    }

    # TODO Preserve this part!
    hook_tome_model(diffusion_model)

    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "DiTBlock"):
            module.__class__ = make_dit_tome_block(module.__class__)
            module._tome_info = diffusion_model._tome_info

    return model


def make_dit_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched class for a DiT block (DiTBlock).
    This patch applies ToMe to the forward function of the block.
    """
    class ToMeBlock(block_class):
        # Save for unpatching later
        _parent = block_class

        def forward(
            self,
            x,
            c
        ):
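            # (1) ToMe: build merge (m_*) and unmerge (u_*) functions for attention,
            # cross-attention, and MLP. DiT blocks have no cross-attention, so m_c / u_c go unused.
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(x, self._tome_info)

            # (2) adaLN modulation parameters from the conditioning vector c
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            # (3) Merge tokens before attention, unmerge afterwards, then add the gated residual
            x = gate_msa.unsqueeze(1) * u_a(self.attn(modulate(m_a(self.norm1(x)), shift_msa, scale_msa))) + x
            # (4) Same pattern for the MLP
            x = gate_mlp.unsqueeze(1) * u_m(self.mlp(modulate(m_m(self.norm2(x)), shift_mlp, scale_mlp))) + x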

            return x

    return ToMeBlock
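
One thing that may be worth checking in the patch above is whether merging is active at all for a given max_downsample. The sketch below assumes that tomesd's compute_merge derives a downsample factor as ceil(sqrt(h * w // num_tokens)) from `_tome_info["size"]` and the token count, and skips merging when that factor exceeds max_downsample. It also assumes the size stays at the 32x32 latent grid while DiT-XL/2 patchifies it into 16x16 = 256 tokens. The helper names `tome_downsample` and `merging_is_active` are hypothetical and not part of tomesd.

```python
import math


def tome_downsample(size, num_tokens):
    """Downsample factor, following the formula assumed in the lead-in."""
    h, w = size
    return int(math.ceil(math.sqrt((h * w) // num_tokens)))


def merging_is_active(size, num_tokens, max_downsample):
    """True if a block with this token count would have its tokens merged."""
    return tome_downsample(size, num_tokens) <= max_downsample


# DiT-XL/2 at 256x256: 32x32 latent, patch size 2, so 16x16 = 256 tokens per block
latent_size = (32, 32)
num_tokens = 16 * 16

for max_downsample in (1, 2):
    active = merging_is_active(latent_size, num_tokens, max_downsample)
    print(f"max_downsample={max_downsample}: merging active = {active}")
```

If these assumptions hold, every DiT block is treated as a 2x-downsampled layer, so the default max_downsample of 1 would leave the model effectively unpatched. Setting the size to the 16x16 token grid instead would be one way to make the default apply.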