swarajnanda2021 closed this issue 9 months ago
In your code for the Attention class, the relative position matrix calculation can be sped up in the following manner.

Common to both:

```python
y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
y_flat, x_flat = y.flatten(), x.flatten()
```

Your implementation (0.01103353500366211 seconds for a 3x3 matrix):

```python
rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation
```

Suggestion (0.0007836818695068359 seconds for a 3x3 matrix):

```python
rel_y = y_flat.flip(dims=[0]).repeat(nn, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
rel_x = x_flat.flip(dims=[0]).repeat(nn, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation
```
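For reference, here is a minimal, self-contained sketch that checks both variants produce the same unique-index matrix and times them. It assumes `nn = ih * iw` and the square-grid meshgrid above; the function names and iteration count are only illustrative, not part of the original code.

```python
import time
import torch

# Assumed setup: ih, iw are the feature-map height/width and nn = ih * iw,
# matching the variable names used in the snippets above.
ih, iw = 3, 3
nn = ih * iw

y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
y_flat, x_flat = y.flatten(), x.flatten()

def rel_pos_original():
    # repeat_interleave builds the "query" index, repeat builds the "key" index
    rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
    rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation

def rel_pos_suggested():
    # Suggested variant: the same matrix built from flipped flat coordinates and 2-D repeats
    rel_y = y_flat.flip(dims=[0]).repeat(nn, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    rel_x = x_flat.flip(dims=[0]).repeat(nn, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # Unique index calculation

# Sanity check: both variants give identical indices for a regular meshgrid.
assert torch.equal(rel_pos_original(), rel_pos_suggested())

# Rough timing comparison (averaged over repeated calls; numbers will vary by machine).
for name, fn in [("original", rel_pos_original), ("suggested", rel_pos_suggested)]:
    t0 = time.perf_counter()
    for _ in range(1000):
        fn()
    print(f"{name}: {(time.perf_counter() - t0) / 1000:.8f} s per call")
```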