facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

How can I use ToMe under PyTorch 1.10.1 #10

Closed Nightwatch-Fox11 closed 1 year ago

Nightwatch-Fox11 commented 1 year ago

Hello, thanks for this amazing work!

I know that I need PyTorch >= 1.12.0 to use ToMe because of the scatter_reduce function. However, for some reason I can't use PyTorch > 1.10.1 in my development environment. I hope to apply ToMe to my ViT-like model, so is it possible to use ToMe under PyTorch 1.10.1?

Moreover, I'm wondering whether there is an elegant way to restore the original order of the tokens that the ToMe module disrupts. For example, how can I use the merged tokens and the source matrix to fill in the missing tokens?

Looking forward to your reply :)

Best

dbolya commented 1 year ago

is it possible for me to use ToMe under PyTorch 1.10.1

Yes, I believe scatter_add should be available. Its behavior is the same as scatter_reduce with reduce="sum", and I think the code uses sum for everything by default. Try replacing scatter_reduce with scatter_add, removing the reduce argument, and see if that works.
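For reference, here is a minimal sketch of that swap on a toy tensor (the shapes and variable names are illustrative, not ToMe's own):

```python
import torch

dst = torch.zeros(2, 4, 8)    # destination tokens (toy shapes)
src = torch.randn(2, 3, 8)    # source tokens to accumulate into dst
idx = torch.randint(0, 4, (2, 3, 1)).expand(-1, -1, 8)

# PyTorch >= 1.12:
# dst = dst.scatter_reduce(-2, idx, src, reduce="sum")

# PyTorch 1.10 equivalent: scatter_add sums src into dst at idx,
# matching scatter_reduce with reduce="sum" (and include_self=True).
dst = dst.scatter_add(-2, idx, src)
```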

For example, how can I use the merged tokens and source matrix to fill the missing tokens?

Just matrix multiply: source.transpose(-1, -2) @ final_tokens.

Source should be $B \times N_o \times N_i$, where $B$ is the batch size, $N_o$ is the number of output tokens at the end of the network, and $N_i$ is the number of input tokens at the start.
Final tokens should be $B \times N_o \times C$, where $C$ is the number of channels.
After the multiplication you should have $B \times N_i \times C$, filling in the missing tokens.

I think those are the right dimensions, but I might have messed up the transpose.
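A minimal sketch of that unmerge, assuming the dimension convention above (the sizes are illustrative, and the random binary source built here is just a stand-in for the matrix ToMe's source tracing produces):

```python
import torch

B, N_i, N_o, C = 2, 197, 99, 768   # illustrative sizes

# source[b, o, i] = 1 if input token i was merged into output token o;
# a random assignment stands in for ToMe's traced source matrix here.
source = torch.zeros(B, N_o, N_i)
source[:, torch.randint(0, N_o, (N_i,)), torch.arange(N_i)] = 1.0

final_tokens = torch.randn(B, N_o, C)

# Transposing source maps each input position back to the merged token
# it ended up in, so the product copies every merged token to all of
# the original positions it absorbed.
unmerged = source.transpose(-1, -2) @ final_tokens
print(unmerged.shape)  # torch.Size([2, 197, 768])
```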

Nightwatch-Fox11 commented 1 year ago

Thanks for your reply; this really helps me a lot. I'll try it later.