pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Feature extraction in torchvision.models.vit_b_16 #5718

Open DavidTorpey opened 2 years ago

DavidTorpey commented 2 years ago

🐛 Describe the bug

Hi

It’s easy enough to obtain output features from the CNNs in torchvision.models by doing this:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
# Drop the final fully connected layer; keep everything up to and including the global average pool
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))

However, when I attempt to do this with torchvision.models.vit_b_16:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vit_b_16()
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))  # raises the AssertionError below

I get the following error:

AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 14, 14])

Any help would be greatly appreciated.

Versions

Torch version: 1.11.0+cu102
Torchvision version: 0.12.0+cu102

cc @datumbox

datumbox commented 2 years ago

@DavidTorpey thanks a lot for reporting!

@alexander-soare I wonder if you have the bandwidth to have a look?

alexander-soare commented 2 years ago

@DavidTorpey the problem is that not everything that happens in the forward method of the ViT model is implemented with submodules; some of the steps are functional transformations.

Specifically in your case, that assertion error is being raised within this call: https://github.com/pytorch/vision/blob/3925946f994a4b779ab9286654e7011bb175a70a/torchvision/models/vision_transformer.py#L267

But if you look at the preceding lines, you will see there are transformations on x that do not use submodules. Indeed, if you run print([n for n, _ in model.named_children()]), you get ['conv_proj', 'encoder', 'heads']. So if you only chain the submodules together, you're missing some steps that happen before the encoder.
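
For reference, the forward pass in that file does roughly the following (a paraphrased sketch of the linked source, not a verbatim copy):

# Paraphrased sketch of VisionTransformer.forward: submodule calls are mixed with functional steps
def forward(self, x):
    # conv_proj patchifies the image inside _process_input, and the result is reshaped
    # from (n, hidden_dim, h, w) to the token sequence (n, h * w, hidden_dim)
    x = self._process_input(x)
    n = x.shape[0]

    # functional step: prepend the learnable class token to the patch tokens
    batch_class_token = self.class_token.expand(n, -1, -1)
    x = torch.cat([batch_class_token, x], dim=1)

    x = self.encoder(x)

    # functional step: keep only the class token's output as the image representation
    x = x[:, 0]

    return self.heads(x)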

The way you handle this is by using TorchVision's FX based feature extraction. Under the hood, this traces through everything that happens in the forward method (submodules and functional transforms). You can read a whole lot more about this here https://pytorch.org/blog/FX-feature-extraction-torchvision/

So long story short, this is how you ought to do it:

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

model = models.vit_b_16()

# These two lines aren't needed, but you can use them to work out which node you want
# _, eval_nodes = get_graph_node_names(model)
# print(eval_nodes)

feature_extractor = create_feature_extractor(model, return_nodes=['encoder'])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))
print(output_features['encoder'].shape)

cc @datumbox. Given the above, I do not think this is a bug.

DavidTorpey commented 2 years ago

@alexander-soare Thanks a lot for your comprehensive answer. I did try create_feature_extractor as you've done, but the resulting shape from your example is torch.Size([1, 197, 768]), whereas I would expect torch.Size([1, 768]), since the final layer that feeds into the prediction MLP (heads) has in_features=768.

Are you able to shed some light on where the 197 is coming from, and how to get a shape of (1, 768)?

Thanks

datumbox commented 2 years ago

@alexander-soare Thanks for dedicating the time to do a deep dive. I understand that the problem here is all the preprocessing that happens in forward() outside of a submodule. Perhaps it's worth adopting, in the future, the approach of structuring the top-level models as a series of submodules. Actually, the majority of models already behave like this. Thoughts?

alexander-soare commented 2 years ago

@DavidTorpey if you look at the output of get_graph_node_names, you will find the last 3 nodes are ['encoder.ln', 'getitem_5', 'heads.head']. You want the getitem_5 node, which does this: https://github.com/pytorch/vision/blob/3925946f994a4b779ab9286654e7011bb175a70a/torchvision/models/vision_transformer.py#L270. It selects the first of the 197 tokens in your torch.Size([1, 197, 768]) (i.e. the "class token").
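
In case it helps, here is a minimal sketch building on the snippet above (the node name getitem_5 comes from get_graph_node_names and may differ between torchvision versions, so check it on yours first):

import torch
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor

model = models.vit_b_16()

# 'getitem_5' corresponds to the x[:, 0] slice that picks out the class token;
# verify the name with get_graph_node_names(model) on your torchvision version.
feature_extractor = create_feature_extractor(model, return_nodes=['getitem_5'])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))
print(output_features['getitem_5'].shape)  # expected: torch.Size([1, 768])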

alexander-soare commented 2 years ago

@datumbox before working on FX, I was trying to structure things that way in some timm models and found myself making weird modules that only made sense in a particular context. For instance, in the ViT example you would end up wrapping the final slice operation in a module. I also tried the approach of absorbing things into modules (like absorbing the slice operation into the encoder module). Sometimes this works; sometimes you find yourself changing the module name and docstring to try to make the semantics appropriate to the new inputs/outputs of the module. A lot of the time, the detour just doesn't feel worth it.
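
To make the above concrete, such a wrapper would look something like this (a hypothetical module, not actual torchvision code):

import torch.nn as nn

# Hypothetical single-purpose wrapper whose only job is to expose the x[:, 0]
# slice as a named submodule; it only really makes sense in the ViT context.
class SelectClassToken(nn.Module):
    def forward(self, x):
        return x[:, 0]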

If the FX solution didn't exist, I'd probably agree that that's the way to go. But with FX, problem solved, no?

DavidTorpey commented 2 years ago

@alexander-soare Thank you! Makes sense