Guoli-w opened this issue 3 months ago
Hello, I need more information than this. Can you share the snippet of code you're trying to run?
Foreword: Thank you very much for the reply, and please forgive my rough description; I was a bit frazzled after being tormented by bugs.

Body: I see two possible ways to apply ToMe to my own model:
1. Follow the example in tome.patch.timm and manually add the sample code to the block and the attention forward in my model.
2. Change my model's blocks and attention to inherit from the Block, Attention, and VisionTransformer classes imported from timm.models.vision_transformer, and then add just one line, tome.patch.timm(model, trace_source=True), when creating the model.

I may not be expressing this clearly, so below is the rough structure of my code, with my questions marked inline.

At last: looking forward to your reply, and thank you for your time. Best regards.
import torch
import torch.nn as nn

from typing import Tuple
from timm.models.vision_transformer import Attention, Block, VisionTransformer
from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
from tome.utils import parse_r
# class trueatt(Attention):  # ???? which one should i choose?
class trueatt(nn.Module):  # ???? which one should i choose?
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, q, k, v):
        B, d, H, W = q.shape
        q, k, v = ****
        return x
# class myattention(Attention):  # ???? which one should i choose?
class myattention(nn.Module):  # ???? which one should i choose?
    def __init__(self, **kwargs):
        super().__init__()
        self.trueatt = trueatt(**kwargs)

    def forward(self, x):
        B, H, W, C = x.shape
        qkv = self.qkv(x).reshape(2, 1, 0, 3, 4, 5)
        # do i need to copy the example from tome/patch/timm.py to here ???
        x = self.trueatt(qkv[0], qkv[1], qkv[2])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# class DilateBlock(nn.Module):
class DilateBlock(Block):  # ???? which one should i choose?
    def __init__(self, **kwargs):
        super().__init__()
        self.attn = myattention(**kwargs)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x
class mystage1(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([DilateBlock(**kwargs)])

    def forward(self, x):
        ****
        return x
# class localvit(nn.Module):
class localvit(VisionTransformer):  # ???? which one should i choose?
    def __init__(self, **kwargs):
        super().__init__()
        self.patch_embed = PatchEmbed()
        self.stages = nn.ModuleList()
        self.stages.append(mystage1(**kwargs))

    def forward(self, x):
        return x
@register_model
def localvit_tiny(pretrained=True, **kwargs):
    model = localvit(**kwargs)
    model.default_cfg = _cfg()
    # way 2: my hand-written Transformer model needs no modification. Can I use ToMe just by adding this one line?
    # tome.patch.timm(model, trace_source=True)
    return model
Hi, sorry for the delay. Here are the answers to your questions as far as I can tell:

For way 2: you can't just call tome.patch.timm directly on your model. The timm patch assumes timm modules, so it will replace your modules with versions of the timm ones (thus overwriting any modifications you make). Thus, for the most basic ToMe, all you need to implement is this:
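As an aside, here is a purely illustrative sketch of why applying a class-swapping patch like this overwrites your changes. All class and function names below are hypothetical, not ToMe's actual internals; the point is that reassigning a module's class discards any subclass overrides:

```python
# Hypothetical sketch of a class-swapping patch; names are illustrative,
# not ToMe's actual implementation.
class TimmBlock:
    def forward(self, x):
        return x + 1


class MyCustomBlock(TimmBlock):
    def forward(self, x):
        return x * 10  # your custom behavior


class ToMeBlock(TimmBlock):
    def forward(self, x):
        return x + 2  # patched-in ToMe behavior


def apply_patch(blocks):
    # The patch swaps each block's class to the patched version,
    # discarding any subclass overrides in the process.
    for b in blocks:
        b.__class__ = ToMeBlock


blocks = [MyCustomBlock()]
apply_patch(blocks)
print(blocks[0].forward(3))  # the custom `* 10` is gone; prints 5
```

This is why custom attention/block modules need the manual route below rather than the patch.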
In localvit's forward function, initialize some variables:
# localvit
def forward(self, x):
    size = None
    r = ****  # Fill this in either at initialization or pull from something like `self.r`.
    # Then pass these parameters into your blocks (through your stages, not pictured here)
    for block in self.blocks:
        x, size = block(x, size, r)
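Since the stages aren't pictured, here is one possible way to thread size and r through a stage's blocks (a sketch; the Stage class name is hypothetical):

```python
import torch
import torch.nn as nn


class Stage(nn.Module):
    """Hypothetical stage that threads ToMe's (size, r) state through
    its blocks, matching the outer loop in localvit's forward."""

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, size, r):
        # Each block consumes and returns the running token sizes.
        for block in self.blocks:
            x, size = block(x, size, r)
        return x, size
```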
In your blocks, consume these parameters and apply ToMe after attention:
# your block (e.g., DilateBlock / myblock)
def forward(self, x, size, r):
    x = x + self.pos_embed(x)
    xa, k = self.attn(self.norm1(x))  # Make sure your attn module returns the mean of k over the heads (e.g., k.mean(1))
    x = x + self.drop_path(xa)
    # Apply ToMe after attention
    if r > 0:
        merge, _ = bipartite_soft_matching(k, r)  # Pass in class_token=True if your model has a class token
        x, size = merge_wavg(merge, x, size)
    # Rest of the block
    ****
    return x, size
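For the self.attn call above, your attention needs to return the head-averaged keys alongside its output. A minimal sketch, assuming a standard timm-style multi-head attention (the class name is hypothetical, and qk bias/dropout are omitted for brevity):

```python
import torch
import torch.nn as nn


class ToMeFriendlyAttention(nn.Module):
    """Hypothetical minimal attention that also returns the head-mean
    of k, which bipartite_soft_matching uses as its merge metric."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (self.qkv(x)
               .reshape(B, N, 3, self.num_heads, C // self.num_heads)
               .permute(2, 0, 3, 1, 4))  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Return keys averaged over heads: [B, N, head_dim]
        return self.proj(x), k.mean(1)
```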
One caveat: ToMe expects the tensors x and k to have shape [batch, tokens, features], so you'll need to reshape / permute them if they're not in that order. This also assumes you have no other pooling operations.
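For example, if your features come out of a conv-style stage as [B, C, H, W], you could flatten them before merging. A sketch (the shapes here are illustrative):

```python
import torch

x = torch.randn(2, 96, 14, 14)         # conv-style feature map [B, C, H, W]
tokens = x.flatten(2).transpose(1, 2)  # -> [B, H*W, C] = [2, 196, 96], the layout ToMe expects
# ... apply ToMe merging to `tokens` here ...
# Caution: merging removes r tokens, so tokens.shape[1] shrinks and the
# result generally can't be folded back into an H x W grid afterwards.
```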
Please help me! (sobs) Thank you!!