Closed Nightwatch-Fox11 closed 2 years ago
> Is it possible for me to use ToMe under PyTorch 1.10.1?
Yes, I believe `scatter_add` should be available. Its behavior is the same as `scatter_reduce` with `reduce="sum"`, and I think by default the code uses sum for everything. Try replacing `scatter_reduce` with `scatter_add` and removing the `reduce` argument, and see if that works.
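A minimal sketch of that replacement, in case it helps. The helper name `scatter_sum` and the toy tensors are made up for illustration and are not part of ToMe; the only library call relied on is `Tensor.scatter_add(dim, index, src)`, which should exist in PyTorch 1.10.1. It only covers the sum case, so any call site that uses a different `reduce` mode would need separate handling.

```python
import torch


def scatter_sum(dst, dim, index, src):
    # Hypothetical stand-in for dst.scatter_reduce(dim, index, src, reduce="sum"):
    # adds every element of `src` into `dst` at the position given by `index`.
    return dst.scatter_add(dim, index, src)


# Toy data: 4 source tokens with 2 channels, merged into 3 destination slots.
src = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]])  # 1 x 4 x 2
dst = torch.zeros(1, 3, 2)                                              # 1 x 3 x 2

# Source tokens 0 and 2 go to slot 0, token 3 to slot 1, token 1 to slot 2.
idx = torch.tensor([0, 2, 0, 1]).view(1, 4, 1).expand(1, 4, 2)

out = scatter_sum(dst, 1, idx, src)
print(out)
```

Slot 0 receives the sum of tokens 0 and 2, while the other two slots each receive a single token.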
> For example, how can I use the merged tokens and source matrix to fill the missing tokens?
Just matrix multiply: `source.transpose(-1, -2) @ final_tokens`.
Source should be $B \times N_o \times N_i$, where $B$ is the batch size, $N_o$ is the number of output tokens at the end of the network, and $N_i$ is the number of input tokens at the start.
Final tokens should be $B \times N_o \times C$ where $C$ is the number of channels.
After the multiplication you should have $B \times N_i \times C$, filling the missing tokens.
I think those are the right dimensions, but I might have messed up the transpose.
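A small shape check for the multiplication above, as a sketch only. The sizes and the contents of `source` and `final_tokens` are invented for illustration, and it assumes `source` is a 0/1 matrix in which each input token belongs to exactly one output token.

```python
import torch

# Toy sizes: B = 1 batch, N_i = 4 input tokens, N_o = 2 output tokens, C = 3 channels.
# Assumed source matrix: output token 0 absorbed inputs 0 and 2,
# output token 1 absorbed inputs 1 and 3.
source = torch.tensor([[[1.0, 0.0, 1.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0]]])        # B x N_o x N_i

final_tokens = torch.tensor([[[1.0, 2.0, 3.0],
                              [4.0, 5.0, 6.0]]])       # B x N_o x C

# (B x N_i x N_o) @ (B x N_o x C) -> B x N_i x C
restored = source.transpose(-1, -2) @ final_tokens
print(restored.shape)
print(restored)
```

Each input position gets a copy of the merged token it ended up in, so inputs 0 and 2 share one row and inputs 1 and 3 share the other. With these shapes the transpose lines up as described.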
Thanks for your reply, this really helps me a lot. I'll try it later.
Hello, thanks for this amazing work!
I know that I need PyTorch >= 1.12.0 when using ToMe because of the scatter_reduce function. However, for some reason, I can't use PyTorch > 1.10.1 in my development environment. I hope to apply ToMe to my ViT-like model, so is it possible for me to use ToMe under PyTorch 1.10.1?
Moreover, I'm wondering whether there is an elegant way to restore the original order of the tokens, which is disrupted by the ToMe module. For example, how can I use the merged tokens and source matrix to fill the missing tokens?
Looking forward to your reply :)
Best