The-AI-Summer / self-attention-cv

Implementation of various self-attention mechanisms focused on computer vision. Ongoing repository.
https://theaisummer.com/
MIT License

Regression with attention #9

Closed: alemelis closed this issue 3 years ago

alemelis commented 3 years ago

Hello!

thanks for sharing this nice repo :)

I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

My understanding is that I'd need to simply define the network as

vit = ViT(img_dim=128,
          in_channels=3,
          patch_dim=16,
          num_classes=6,
          dim=512)

and during training call

vit(x)

and compute the loss as MSE instead of CE.
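The setup described above can be sketched as a single training step. This is a minimal sketch assuming PyTorch; a plain `Linear` layer stands in for the repo's ViT here only to keep the snippet self-contained, since any module mapping images to 6 floats fits the same loop.

```python
import torch
import torch.nn as nn

# Stand-in for ViT(img_dim=128, in_channels=3, patch_dim=16,
# num_classes=6, dim=512); same input/output shapes.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 128 * 128, 6))

criterion = nn.MSELoss()  # regression loss in place of CrossEntropyLoss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(8, 3, 128, 128)  # batch of 8 images
y = torch.randn(8, 6)            # 6 float targets per image

optimizer.zero_grad()
pred = model(x)             # shape (8, 6), raw unbounded outputs
loss = criterion(pred, y)   # mean squared error over the batch
loss.backward()
optimizer.step()
```

Note that the regression head produces raw, unbounded values, so target scaling (e.g. normalizing the 6 floats to a comparable range) can matter for convergence with MSE.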

The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

many thanks!

black0017 commented 3 years ago

Hello, I don't have practical experience with regression tasks. I recently came across this article on regression with vanilla transformers: https://www.linkedin.com/pulse/how-i-turned-nlp-transformer-time-series-predictor-zimbres-phd/

From a quick glance, I see that the author uses a very small transformer (d_model=4, d_ff=32), while the original model has d_model=512, d_ff=2048.

These correspond to dim=4 and dim_linear_block=32 in this implementation of ViT.

I would try a smaller model in general, and a single regression value first. Let me know if it helps.
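The smaller configuration suggested above could look like the following. This is a sketch assuming the package imports as `self_attention_cv` and that its ViT constructor accepts `dim_linear_block`, as named in this thread; check the repo's source for the exact signature.

```python
from self_attention_cv import ViT  # assumed import path for this package

# Scaled-down model for a first regression experiment, mirroring the
# tiny transformer (d_model=4, d_ff=32) from the linked article.
vit_small = ViT(img_dim=128,
                in_channels=3,
                patch_dim=16,
                num_classes=1,        # start with a single regression target
                dim=4,                # d_model analogue
                dim_linear_block=32)  # d_ff analogue
```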