Closed isaaccorley closed 3 months ago
@isaaccorley already on the radar but haven't had a chance to come up with a design yet. See #617. I was going to rename the title of that but I'll keep this one open instead and close 617.
One complication is that new vision transformer models have been coming in at a rapid rate. Any solution has to work for all of them, whether their features are spatial or flattened, whether the user wants flat features unflattened when there is enough info to do so, etc etc.
I would also like to do it without significant modifications or new conditionals in any forward fns, just by adding extra metadata in a feature_info attribute, since it is specific to each model. This is what I did for CNNs here, and I wrote FeatureExtractor modules to work with the metadata.
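A minimal illustration of that metadata idea (all names and fields below are hypothetical, not timm's actual feature_info schema): a generic extractor can consult per-feature descriptors, including whether flattened token features carry enough info to be unflattened, instead of hard-coding model-specific conditionals.

```python
# Hypothetical metadata sketch -- field names are illustrative, not timm's schema.
# Each entry describes one intermediate output: the module it comes from, its
# channel count, and its layout ('spatial' = BCHW map, 'tokens' = flat B,N,C).
feature_info = [
    {"module": "blocks.3", "num_chs": 192, "layout": "tokens", "grid_size": (14, 14)},
    {"module": "blocks.7", "num_chs": 384, "layout": "tokens", "grid_size": (14, 14)},
    {"module": "head_norm", "num_chs": 768, "layout": "tokens"},  # no grid info recorded
]

def can_unflatten(info):
    """Flat token features can only be restored to a spatial map if the
    metadata records the token grid size."""
    return info["layout"] == "tokens" and "grid_size" in info

print([can_unflatten(f) for f in feature_info])  # [True, True, False]
```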
There is also the consideration of tokens: whether to include them where possible (if they exist) or not. See #655.
And finally, the torchvision team let me know they're working on some new functionality using torch.fx that would allow writing more general feature extraction. I was thinking of exploring that as a possible approach...
@rwightman After reviewing some of the other transformer variant code I understand what you mean. I have some additional bandwidth to help out if needed when you want to tackle this. I think in the meantime a FeatureExtractor module for vision transformers would work nicely, but torch.fx also seems like it has the ability to solve this exact problem.
@isaaccorley some help could be useful here. I need to ask the torchvision team and figure out what the fx solution might look like, whether it has a chance of working well with torchscript, and then make a decision on putting in the effort there or making a vision transformer / vision MLP specific feature extraction module.
The same problem is present in other models, such as 'mobilenetv3_large_100'. The global_pool operator is embedded in the 'forward_features' function, which is inconsistent with the tutorial at https://rwightman.github.io/pytorch-image-models/feature_extraction/ @rwightman
@JarvisKevin mnv3 is a completely different situation. If you read the comments in that model definition, that architecture by design does not match other convnets due to its 'efficient head', where global avg pooling is moved in front of the penultimate feature layer. The current choice for forward_features vs features was made to balance simplicity and compatibility with as many of the feature extraction use cases as possible.
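To make the 'efficient head' point concrete, a rough sketch of the layer ordering (stage names are illustrative, not timm's exact attribute names): in mnv3 the global pool runs before the final 1x1 conv, so spatial resolution is already gone by the time the last feature layer is produced.

```python
# Illustrative orderings only -- not timm's actual module layouts.
typical_convnet_head = ["blocks", "conv_head", "global_pool", "flatten", "classifier"]
mnv3_efficient_head = ["blocks", "global_pool", "conv_head", "flatten", "classifier"]

# In the efficient head, pooling precedes the final conv, so forward_features
# cannot return a spatial map at that final feature width.
print(mnv3_efficient_head.index("global_pool") < mnv3_efficient_head.index("conv_head"))  # True
```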
Re ViT... I have in-progress refactoring / feature changes for vit/mlp models in the works. It's pretty significant and relies on some helpers that won't be in torchvision until the PyTorch 1.10 release, so I'm sitting on them a while and making sure they are the right approach.
Hello, in some other code I've seen the unpooled features taken before x = self.pos_drop(x + self.pos_embed), while the approaches above take them after. I wonder which method works better.
I believe this issue should be considered resolved now that forward_features has been implemented. (https://github.com/huggingface/pytorch-image-models/issues/1029)
It seems that VisionTransformer doesn't support feature extraction of all outputs in the forward_features method. Only returning the cls token or [cls_token, distillation_token] is available (timm/models/vision_transformer.py#L291-L304). This functionality seems particularly useful, similar to how pretrained ResNet features are commonly used for downstream tasks. However, this is available for other models, e.g.
I implemented something similar for a side project here that required all ViT outputs for a downstream segmentation task. I simply override the method in my example, but I assume some attribute could be added to VisionTransformer to allow for returning 'unpooled' output. Maybe something like this:
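A rough sketch of that idea, using a hypothetical pool attribute on a minimal stand-in module rather than timm's actual VisionTransformer class:

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    """Minimal stand-in for VisionTransformer -- illustrative only."""
    def __init__(self, embed_dim=8, pool=True):
        super().__init__()
        self.pool = pool  # hypothetical attribute: pooled vs unpooled output
        self.norm = nn.LayerNorm(embed_dim)

    def forward_features(self, x):
        # x: (B, N, C) token sequence after the transformer blocks
        x = self.norm(x)
        if self.pool:
            return x[:, 0]  # cls token only, as the current method does
        return x            # 'unpooled': every token, for dense downstream tasks

m = TinyViT(pool=False)
tokens = torch.rand(2, 197, 8)
print(tuple(m.forward_features(tokens).shape))  # (2, 197, 8)
```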