facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

Question regarding scatter_reduce #43

Closed: AshStuff closed this 2 months ago

AshStuff commented 2 months ago

I see that there is a scatter_reduce call at https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py#L105 which sums the dst and src tokens. However, PyTorch's scatter_reduce (https://pytorch.org/docs/stable/generated/torch.scatter_reduce.html) has include_self=True by default, so it seems to add the first r slices of dst as well. Do you think we need to set include_self=False?
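To make the question concrete, here is a small sketch of what include_self changes in torch's Tensor.scatter_reduce. The toy tensors, shapes, and indices below are illustrative assumptions and are not taken from tomesd; the sketch only shows the documented PyTorch semantics, not which setting the repo should use.

```python
import torch

# Toy example (hypothetical values): 3 dst tokens and 2 src tokens, 1 channel each.
dst = torch.tensor([[10.0], [20.0], [30.0]])
src = torch.tensor([[1.0], [2.0]])
# Both src tokens are scattered into dst token 0.
idx = torch.tensor([[0], [0]])

# include_self=True (the default): the existing dst value joins the reduction.
with_self = dst.scatter_reduce(0, idx, src, reduce="sum", include_self=True)
# include_self=False: dst values at scattered indices are excluded from the reduction.
without_self = dst.scatter_reduce(0, idx, src, reduce="sum", include_self=False)

# Same comparison with reduce="mean", where the difference also changes the divisor.
mean_with_self = dst.scatter_reduce(0, idx, src, reduce="mean", include_self=True)
mean_without_self = dst.scatter_reduce(0, idx, src, reduce="mean", include_self=False)

print(with_self.squeeze(1).tolist())     # [13.0, 20.0, 30.0]
print(without_self.squeeze(1).tolist())  # [3.0, 20.0, 30.0]
```

In both modes, dst rows that receive no src values (rows 1 and 2 here) keep their original values. Only the scattered-to rows differ: with include_self=True the dst token itself is part of the sum or mean (row 0 becomes 10 + 1 + 2, or (10 + 1 + 2) / 3 for "mean"), while with include_self=False the dst token's own value is dropped (row 0 becomes 1 + 2, or 1.5 for "mean").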