DavidTorpey opened this issue 2 years ago
@DavidTorpey thanks a lot for reporting!
@alexander-soare I wonder if you have the bandwidth to have a look?
@DavidTorpey the problem is that not everything in the `forward` method of the ViT model is implemented with submodules. Some steps are functional transformations.
Specifically, in your case the assertion error is raised within this call: https://github.com/pytorch/vision/blob/3925946f994a4b779ab9286654e7011bb175a70a/torchvision/models/vision_transformer.py#L267
But if you look at the preceding lines, you will see transformations on `x` that do not use submodules (reshaping the patch embeddings into a sequence and prepending the class token). Indeed, if you run `print([n for n, _ in model.named_children()])`, you get `['conv_proj', 'encoder', 'heads']`. So if you only use submodules, you're missing the steps that run before the `encoder`, which expects a 3-D `(batch, sequence, hidden)` token sequence.
The way to handle this is to use TorchVision's FX-based feature extraction. Under the hood, it traces through everything that happens in the `forward` method (submodules and functional transforms). You can read a lot more about it here: https://pytorch.org/blog/FX-feature-extraction-torchvision/
Long story short, this is how you ought to do it:
```python
import torch
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

model = models.vit_b_16()
# These two lines are not needed, but you would use them to work out which node you want:
# _, eval_nodes = get_graph_node_names(model)
# print(eval_nodes)
feature_extractor = create_feature_extractor(model, return_nodes=['encoder'])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))
print(output_features['encoder'].shape)  # torch.Size([1, 197, 768])
```
cc @datumbox: so I don't think this is a bug.
@alexander-soare Thanks a lot for your comprehensive answer. However, I did try this with `create_feature_extractor` as you've done, and the resulting size from your example is `torch.Size([1, 197, 768])`, but I would expect it to be `torch.Size([1, 768])`, since the prediction MLP (`heads`) that consumes the final features has `in_features=768`.
Are you able to shed some light on where the 197 is coming from, and how to get a shape of (1, 768)?
Thanks
@alexander-soare Thanks for taking the time to do a deep dive. I understand that all the preprocessing that happens in `forward()` outside of a module is the problem here. Perhaps it's worth adopting, in the future, the approach of structuring top-level models as a series of submodules. Actually, the majority of models already behave like this. Thoughts?
@DavidTorpey if you look at the output of `get_graph_node_names`, you will find the last 3 nodes are `['encoder.ln', 'getitem_5', 'heads.head']`. You want the `getitem_5` node, which does this: https://github.com/pytorch/vision/blob/3925946f994a4b779ab9286654e7011bb175a70a/torchvision/models/vision_transformer.py#L270. It selects the first of the 197 tokens in your `torch.Size([1, 197, 768])` (i.e. the "class token"). The 197 comes from 224 / 16 = 14 patches per side, giving 14 × 14 = 196 patch tokens, plus the one class token prepended before the encoder.
@datumbox before working on FX, I was trying to structure things that way in some timm models and found myself making weird modules that only made sense in a particular context. For instance, in the ViT example you would end up wrapping the final slice operation in a module... I also tried absorbing things into modules (like absorbing the slice operation into the encoder module). Sometimes this works; other times you find yourself changing the module name and docstring to make the semantics fit the new inputs/outputs of the module. A lot of the time, the detour just doesn't feel worth it.
If the FX solution didn't exist, I'd probably agree that that's the way to go. But with FX, problem solved, no?
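For contrast, a sketch of the module-wrapping alternative discussed above. `SelectClassToken` is a hypothetical module invented for this example, and the two `nn.Linear` layers are stand-ins rather than real ViT components. It shows the trade-off: once the slice is a child module, truncating with `nn.Sequential(*list(model.children())[:-1])` works, but the wrapper only makes sense for token sequences whose first entry is a class token.

```python
import torch
import torch.nn as nn


class SelectClassToken(nn.Module):
    """Hypothetical wrapper that turns the slice x[:, 0] into a submodule.

    It has no parameters and is only meaningful for (N, S, E) token
    sequences whose first token is a class token.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, 0]


# Toy model built purely from submodules (the Linear layers are stand-ins).
model = nn.Sequential(
    nn.Linear(768, 768),   # stand-in for an encoder operating on tokens
    SelectClassToken(),    # the wrapped slice
    nn.Linear(768, 1000),  # classification head
)

# Because every step is a child module, dropping the head is enough.
backbone = nn.Sequential(*list(model.children())[:-1])
tokens = torch.randn(2, 197, 768)
print(backbone(tokens).shape)  # torch.Size([2, 768])
print(model(tokens).shape)     # torch.Size([2, 1000])
```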
@alexander-soare Thank you! Makes sense
🐛 Describe the bug
Hi
It’s easy enough to obtain output features from the CNNs in torchvision.models by doing this:
However, when I attempt to do this with torchvision.models.vit_b_16:
I get the following error:
Any help would be greatly appreciated.
Versions
Torch version: 1.11.0+cu102
Torchvision version: 0.12.0+cu102
cc @datumbox