facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

ToMeBlock cannot be used with torch.utils.checkpoint #8

Closed powermano closed 1 year ago

powermano commented 1 year ago

When using a relatively small ViT model like VIT_TI, we do not need torch.utils.checkpoint. But for VIT_L or VIT_H, it is necessary to use torch.utils.checkpoint to save a lot of GPU memory.

The forward_features of the original timm VisionTransformer is as follows:

def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

Then we add torch.utils.checkpoint to it as follows, but it causes an error during loss.backward().

def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    # x = self.blocks(x)
    for func in self.blocks:
        if self.using_checkpoint and self.training:
            from torch.utils.checkpoint import checkpoint
            x = checkpoint(func, x)
        else:
            x = func(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

The error is as follows. I have verified that it is caused by torch.utils.checkpoint.

 File "/mnt/my_dist/code/classification/backbones/vit_tome.py", line 417, in forward
    attn = attn + size.log()[:, None, None, :, 0]
RuntimeError: The size of tensor a (29) must match the size of tensor b (24) at non-singleton dimension 3

Can anyone help with a solution? Thanks in advance.

dbolya commented 1 year ago

Hmm, somehow it looks like the tokens aren't getting reduced in a block. Can you debug why that's the case?

RuntimeError: The size of tensor a (29) must match the size of tensor b (24) at non-singleton dimension 3

I'm assuming this means attn has 29 tokens at dim 3, while size has 24 tokens. These should always have matching numbers of tokens, because the tokens are reduced in x at the same time size is calculated. So perhaps checkpointing is saving the wrong version of the features?

I'm not entirely sure how checkpointing works internally, but is there a way to disable it on the ToMe part?
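
For reference, torch.utils.checkpoint stores only the inputs of the wrapped function and re-runs the function during the backward pass to rebuild its activations. A minimal sketch (plain PyTorch, independent of ToMe) that makes this second execution visible:

import torch
from torch.utils.checkpoint import checkpoint

calls = []

def block(x):
    # Any external state read here is read again at backward time,
    # because checkpoint re-executes this function to rebuild activations.
    calls.append("forward")
    return x * x

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x)  # block runs once during the forward pass
y.sum().backward()        # block runs a second time during recomputation
print(calls)              # ['forward', 'forward']

If the wrapped block reads state that lives outside its inputs (such as ToMe's running token size), that state must be identical at recomputation time, otherwise the rebuilt activations will not match the original forward pass.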

powermano commented 1 year ago

Checkpointing discards some internal activations, so the correct tensors (the attn over 24 tokens) were possibly discarded. During backpropagation, checkpointing needs to recompute those correct features (with 24 tokens).

The input to the last block is a feature with 29 tokens; it first goes through ToMeAttention, and then the tokens are merged. At that point the output is a feature with 24 tokens, and size also has 24 tokens.

And I cannot find a way to disable checkpointing on the ToMe part.
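
One possible direction (not an official fix, just a sketch): assuming the ToMe timm patch keeps the running token size in a dict shared across blocks (model._tome_info["size"]) and each patched block reads and overwrites it in place, the pre-block size could be captured alongside the checkpointed call and restored whenever the block is re-executed during backward, so the recomputation sees the same state as the original forward:

from torch.utils.checkpoint import checkpoint

def checkpoint_tome_block(block, x, tome_info):
    # Hypothetical helper. `tome_info` is assumed to be the dict the ToMe
    # patch shares across blocks, with a "size" entry that each block
    # reads and updates in place.
    size_before = tome_info["size"]

    def run(x_inner):
        # Restore the size this block saw in the original forward pass,
        # so the checkpoint recomputation produces matching shapes.
        tome_info["size"] = size_before
        return block(x_inner)

    return checkpoint(run, x)

# In forward_features, instead of x = checkpoint(func, x):
#     x = checkpoint_tome_block(func, x, self._tome_info)

This keeps the pre-block size tensor alive, which costs a little extra memory, but it lets the recomputed block reproduce the shapes the backward pass expects.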

powermano commented 1 year ago

Instead of checkpointing, we can reduce the batch size to save GPU memory, and then use gradient accumulation to ensure convergence.
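
A minimal sketch of that alternative, with hypothetical names (model, criterion, optimizer, loader, accum_steps):

accum_steps = 4  # effective batch size = accum_steps * micro-batch size
optimizer.zero_grad()
for step, (images, targets) in enumerate(loader):
    loss = criterion(model(images), targets) / accum_steps  # average over the window
    loss.backward()                  # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()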

boheumd commented 1 year ago

@powermano @dbolya Hello, did you figure out a way to train with ToMe while using torch.utils.checkpoint?