facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

AttributeError: 'ToMeVisionTransformer' object has no attribute 'blocks' #40

Open Guoli-w opened 3 months ago

Guoli-w commented 3 months ago

Please help me! (crying) Thanks!!

dbolya commented 3 months ago

Hello, I need more information than this. Can you share the snippet of code you're trying to run?

Guoli-w commented 3 months ago

Foreword: Thank you very much for the reply, and please forgive my rough description; I was a bit frazzled after being tormented by these bugs.

Body: I see two possible ways to apply ToMe to my own model:

1. Follow the example in tome.patch.timm and copy the sample code into my model's block and into attention's forward.
2. Only change my blocks and attention to inherit from the Block, Attention, and VisionTransformer imported from timm.models.vision_transformer, and then add a single line, tome.patch.timm(model, trace_source=True), when creating the model.

I may not be expressing this clearly, so below is the rough structure of my code, with my questions marked inline for you.

At last: Looking forward to your reply, and thank you for your time. Best regards.


from typing import Tuple

import torch.nn as nn
from timm.models.layers import PatchEmbed
from timm.models.registry import register_model
from timm.models.vision_transformer import Attention, Block, VisionTransformer, _cfg

from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
from tome.utils import parse_r

# class trueatt(Attention):   # ???? which one should I choose?
class trueatt(nn.Module):      # ???? which one should I choose?
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, q, k, v):
        B, d, H, W = q.shape
        q, k, v = ****  # (attention computation omitted)
        return x

# class myattention(Attention):   # ???? which one should I choose?
class myattention(nn.Module):      # ???? which one should I choose?
    def __init__(self, **kwargs):
        super().__init__()

        self.trueatt = trueatt(**kwargs)

    def forward(self, x):
        B, H, W, C = x.shape
        qkv = self.qkv(x).reshape(2, 1, 0, 3, 4, 5)  # (rough sketch of the qkv projection)

        # ???? Do I need to copy the example from tome/patch/timm.py into here?
        x = self.trueatt(qkv[0], qkv[1], qkv[2])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# class DilateBlock(nn.Module):
class DilateBlock(Block):   # ???? which one should I choose?
    def __init__(self, **kwargs):
        super().__init__()

        self.attn = myattention(**kwargs)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x + self.drop_path(self.attn(self.norm1(x)))

        return x

class mystage1(nn.Module):
    def __init__(self, **kwargs):

        super().__init__()
        self.blocks = nn.ModuleList([
            myblock(**kwargs)])

    def forward(self, x):
        ****
        return x

# class localvit(nn.Module):
class localvit(VisionTransformer):   # ???? which one should I choose?
    def __init__(self, **kwargs):
        super().__init__()

        self.patch_embed = PatchEmbed()
        self.stages = nn.ModuleList()
        # ... build and append each stage (details omitted)

    def forward(self, x):
        return x

@register_model
def localvit_tiny(pretrained=True, **kwargs):
    model = localvit(**kwargs)
    model.default_cfg = _cfg()
    # Way 2: my own Transformer model needs no modification. Can I use ToMe just by adding this one line?
    # tome.patch.timm(model, trace_source=True)

    return model
dbolya commented 3 months ago

Hi, sorry for the delay. Here are the answers to your questions, as far as I can tell:

  1. For whether you should inherit nn.Module or one of Block/Attention, etc., that depends on if you reuse code / functions from those modules. From what I can tell, you intend to reimplement the whole transformer, right? You're not using any of the default behavior of those modules, so you don't need to inherit them.
  2. I would try first implementing ToMe without proportional attention (this is the part that modifies the attention module). Once you have basic ToMe working, then you can try modifying attention. This way, you can debug each part independently. Thus, I suggest you only modify the transformer part and the block part.
  3. Do not use tome.patch.timm directly on your model. The timm patch assumes timm modules, so it will replace your modules with versions of the timm ones (thus overwriting any modifications you make).
  4. Since you're the one implementing ToMe and the model itself, you don't need to create a "patch". The patches are only there because I don't own the code for what I'm patching, so I need to edit the code at runtime. In your case, I presume you can just bake ToMe directly into the transformer itself.

Thus, for the most basic ToMe, all you need to implement is this:

  1. In localvit's forward function, initialize some variables:

    # localvit
    def forward(self, x):
        size = None
        r = ****  # Fill this in either at initialization or pull from something like `self.r`.

        # Then pass these parameters into your blocks (through your stages, not pictured here)
        for block in self.blocks:
            x, size = block(x, size, r)
  2. In your blocks, consume these parameters and apply ToMe after attention:

    # myblock
    def forward(self, x, size, r):
        x = x + self.pos_embed(x)
        xa, k = self.attn(self.norm1(x))  # Make sure your attn module returns the mean of k over the heads (e.g., k.mean(1)); see the sketch below
        x = x + self.drop_path(xa)

        # Apply ToMe after attention
        if r > 0:
            merge, _ = bipartite_soft_matching(k, r)  # Pass in class_token=True if your model has a class token
            x, size = merge_wavg(merge, x, size)

        # Rest of the block
        ****

        return x, size
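
For reference, here is a rough sketch (not taken verbatim from the repo) of an attention forward that also returns the head-averaged keys. It assumes a standard timm-style attention module where `self.qkv`, `self.num_heads`, `self.scale`, `self.attn_drop`, `self.proj`, and `self.proj_drop` are already defined; adapt it to whatever your own attention looks like:

    # Rough sketch, assuming a timm-style Attention module (self.qkv, self.num_heads,
    # self.scale, self.attn_drop, self.proj, self.proj_drop defined in __init__).
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # Return the mean of k over the heads so the block can hand it to ToMe
        return x, k.mean(1)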

One caveat is that ToMe expects the tensors x and k to have a shape of [batch, tokens, features], so you'll need to reshape / permute them if they're not already in that order. This also assumes you have no other pooling operations.
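
As a concrete (hedged) example of that reshape, assuming your features come out of attention as a [B, H, W, C] grid (the variable names here are just for illustration):

    # Hedged sketch: flatten a [B, H, W, C] feature map into [batch, tokens, features]
    # before calling ToMe; the keys k must be flattened in the same token order.
    B, H, W, C = x.shape
    x = x.reshape(B, H * W, C)   # tokens = H * W
    k = k.reshape(B, H * W, -1)  # head-averaged keys, same token ordering as x

    merge, _ = bipartite_soft_matching(k, r)
    x, size = merge_wavg(merge, x, size)
    # After merging there are fewer than H * W tokens, so you can't simply reshape back
    # to an H x W grid; the layers after this point have to work on the token dimension.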