Thank you for providing the detailed implementation of ToMe. To verify its effectiveness, I applied your method to the DiT-XL/2 model trained on ImageNet at 256x256 (available at https://github.com/facebookresearch/DiT).
However, I observed that the benefit of ToMe is limited at a resolution of 256 compared to the results obtained at higher resolutions such as 512. Although setting max_downsample to 2 resulted in faster inference, I did not observe any significant differences among the different ratio values.
I am therefore wondering whether I implemented ToMe incorrectly, or whether ToMe's gains really are relatively marginal compared to those at high resolutions (512 and above).
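For concreteness, the comparison above came from a simple timing sweep of the kind sketched below. This is a schematic reconstruction rather than my exact benchmark: the batch size, iteration count, and random latents/timesteps/labels are placeholders, model is the DiT-XL/2 instance loaded as shown at the end of this post, and apply_patch is the adaptation shown below.

import time
import torch

device = "cuda"

# Schematic sweep over merge ratios at 256x256 (latent size 32x32, i.e. 16x16 = 256
# tokens after DiT's patch embedding). Batch size, timesteps, and class labels are
# random placeholders; only the relative timings matter here.
for ratio in [0.25, 0.5, 0.75]:
    apply_patch(model, ratio=ratio, max_downsample=2)

    x = torch.randn(8, 4, 32, 32, device=device)      # latent inputs
    t = torch.randint(0, 1000, (8,), device=device)   # diffusion timesteps
    y = torch.randint(0, 1000, (8,), device=device)   # ImageNet class labels

    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(50):
            model(x, t, y)
    torch.cuda.synchronize()
    print(f"ratio={ratio}: {(time.time() - start) / 50:.4f} s per forward pass")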
To apply ToMe to the DiT model, I changed only a small portion of your code, as follows:
import torch
from typing import Type

# Assumes the ToMe helpers referenced below (remove_patch, hook_tome_model,
# compute_merge, isinstance_str) are imported unchanged from your code, and that
# modulate and DiTBlock come from the DiT repo's models.py.


def apply_patch(
        model: torch.nn.Module,
        ratio: float = 0.5,
        max_downsample: int = 1,
        sx: int = 2, sy: int = 2,
        use_rand: bool = True,
        merge_attn: bool = True,
        merge_crossattn: bool = False,
        merge_mlp: bool = False):
"""
Patches a DiT model with ToMe.
Apply this to the highest level stable diffusion object (i.e., it should have a .model.diffusion_model).
Important Args:
- model: A top level DiT model module to patch in place. Should have a ".model.diffusion_model"
- ratio: The ratio of tokens to merge. I.e., 0.4 would reduce the total number of tokens by 40%.
The maximum value for this is 1-(1/(sx*sy)). By default, the max is 0.75 (I recommend <= 0.5 though).
Higher values result in more speed-up, but with more visual quality loss.
Args to tinker with if you want:
- max_downsample [1, 2, 4, or 8]: Apply ToMe to layers with at most this amount of downsampling.
E.g., 1 only applies to layers with no downsampling (4/15) while
8 applies to all layers (15/15). I recommend a value of 1 or 2.
- sx, sy: The stride for computing dst sets (see paper). A higher stride means you can merge more tokens,
but the default of (2, 2) works well in most cases. Doesn't have to divide image size.
- use_rand: Whether or not to allow random perturbations when computing dst sets (see paper). Usually
you'd want to leave this on, but if you're having weird artifacts try turning this off.
- merge_attn: Whether or not to merge tokens for attention (recommended).
- merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
- merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).
"""
    # Make sure the module is not currently patched
    remove_patch(model)
    diffusion_model = model

    # TODO Preserve this part!
    diffusion_model._tome_info = {
        "size": (32, 32),
        "hooks": [],
        "args": {
            "ratio": ratio,
            "max_downsample": max_downsample,
            "sx": sx, "sy": sy,
            "use_rand": use_rand,
            "generator": None,
            "merge_attn": merge_attn,
            "merge_crossattn": merge_crossattn,
            "merge_mlp": merge_mlp
        }
    }

    # TODO Preserve this part!
    hook_tome_model(diffusion_model)

    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "DiTBlock"):
            make_tome_block_fn = make_dit_tome_block
            module.__class__ = make_tome_block_fn(module.__class__)
            module._tome_info = diffusion_model._tome_info

    return model

def make_dit_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched block class for a DiT model.
    This patch applies ToMe to the forward function of the block.
    """
    class ToMeBlock(block_class):
        # Save for unpatching later
        _parent = block_class

        def forward(self, x, c):
            # (1) ToMe: m_* merge tokens before each sub-block, u_* unmerge them afterwards
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(x, self._tome_info)

            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            x = gate_msa.unsqueeze(1) * u_a(self.attn(modulate(m_a(self.norm1(x)), shift_msa, scale_msa))) + x
            x = gate_mlp.unsqueeze(1) * u_m(self.mlp(modulate(m_m(self.norm2(x)), shift_mlp, scale_mlp))) + x
            return x

    return ToMeBlock
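For completeness, here is roughly how I construct and patch the model before sampling. The loading code follows the DiT repo's sample.py as far as I can tell; the checkpoint name, device, and the ratio/max_downsample values are just the ones from my setup.

import torch
from models import DiT_models     # from the DiT repo
from download import find_model   # from the DiT repo

device = "cuda"
image_size = 256
latent_size = image_size // 8     # 32 for ImageNet 256x256

# Build DiT-XL/2 and load the pretrained 256x256 checkpoint, as in sample.py.
model = DiT_models["DiT-XL/2"](input_size=latent_size, num_classes=1000).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
model.eval()

# Apply the ToMe patch defined above to every DiTBlock.
apply_patch(model, ratio=0.5, max_downsample=2)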