chinhsuanwu / coatnet-pytorch

A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes"
https://arxiv.org/abs/2106.04803
MIT License

Perhaps a nicer way to estimate the relative position #18

Closed swarajnanda2021 closed 9 months ago

swarajnanda2021 commented 9 months ago

In your code for the Attention class, the relative position matrix calculation can be sped up as follows:

Common to both:

    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
    y_flat, x_flat = y.flatten(), x.flatten()

Your implementation: 0.01103353500366211 seconds for a 3x3 matrix

    rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
    rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
    rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation

Suggestion: 0.0007836818695068359 seconds for a 3x3 matrix

    rel_y = y_flat.flip(dims=[0]).repeat(nn, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    rel_x = x_flat.flip(dims=[0]).repeat(nn, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation
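
For anyone who wants to check that the two variants agree, here is a minimal sketch that wraps each one in a function and compares the outputs. It assumes `nn = ih * iw` (the number of grid positions), which is what the `.view(nn, nn)` calls require. The function names and the third, broadcasting-based variant are illustrative additions, not code from the repository.

```python
import torch


def _flat_grid(ih, iw):
    # Shared setup: flattened row (y) and column (x) coordinates of an ih x iw grid.
    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
    return y.flatten(), x.flatten()


def rel_pos_repeat(ih, iw):
    # Variant quoted above as the existing implementation.
    nn = ih * iw  # assumption: nn is the number of grid positions
    y_flat, x_flat = _flat_grid(ih, iw)
    rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
    rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)


def rel_pos_flip(ih, iw):
    # Variant proposed above, built from flipped coordinate vectors.
    nn = ih * iw  # same assumption as above
    y_flat, x_flat = _flat_grid(ih, iw)
    yf, xf = y_flat.flip(dims=[0]), x_flat.flip(dims=[0])
    rel_y = yf.repeat(nn, 1) - yf.view(-1, 1).repeat(1, nn)
    rel_x = xf.repeat(nn, 1) - xf.view(-1, 1).repeat(1, nn)
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)


def rel_pos_broadcast(ih, iw):
    # Hypothetical third variant: rely on broadcasting instead of explicit repeats.
    y_flat, x_flat = _flat_grid(ih, iw)
    rel_y = y_flat[:, None] - y_flat[None, :]
    rel_x = x_flat[:, None] - x_flat[None, :]
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)


if __name__ == "__main__":
    for ih, iw in [(3, 3), (3, 5), (7, 7)]:
        ref = rel_pos_repeat(ih, iw)
        assert torch.equal(ref, rel_pos_flip(ih, iw))
        assert torch.equal(ref, rel_pos_broadcast(ih, iw))
    print("all variants agree")
```

The flipped version matches because reversing `y_flat` turns each entry `y` into `ih - 1 - y` (and each `x` into `iw - 1 - x`), so swapping the order of the subtraction restores the original sign. Since a single wall-clock measurement on a 3x3 grid can be noisy, something like `torch.utils.benchmark.Timer` may give steadier numbers when comparing the variants.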