chinhsuanwu / coatnet-pytorch

A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes"
https://arxiv.org/abs/2106.04803
MIT License

About the w_{i-j} #14

Status: Open. RioLLee opened this issue 2 years ago.

RioLLee commented 2 years ago

Thanks for sharing this implementation. We want to confirm: is relative_coords a learnable parameter or a constant in CoAtNet?

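        import torch
        from einops import rearrange

        # Absolute (row, col) coordinate of every position in the ih x iw grid.
        # indexing="ij" (PyTorch >= 1.10) makes the default behavior explicit and silences the warning.
        coords = torch.meshgrid(torch.arange(self.ih), torch.arange(self.iw), indexing="ij")
        coords = torch.flatten(torch.stack(coords), 1)  # (2, N), N = ih * iw

        # Pairwise offsets (dy, dx) between every query position and every key position.
        relative_coords = coords[:, :, None] - coords[:, None, :]  # (2, N, N)
        # Shift offsets to be non-negative: dy in [0, 2*ih - 2], dx in [0, 2*iw - 2].
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        # Row-major encoding of (dy, dx) into one integer index.
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)  # (N * N, 1)
        # A buffer is stored in state_dict and moves with .to(device), but it is not a trainable parameter.
        self.register_buffer("relative_index", relative_index)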
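The question hinges on the difference between a buffer and a parameter, so here is a minimal sketch that isolates it (assuming PyTorch >= 1.10 for `indexing="ij"`). The module name `RelativeBias`, the `relative_bias_table` parameter with shape `((2*ih-1)*(2*iw-1), heads)`, and the `gather`-based lookup are illustrative assumptions, not necessarily the repository's exact code; `permute` stands in for `einops.rearrange` to avoid the extra dependency. What it shows: a tensor registered with `register_buffer` is saved in `state_dict` and follows the module across devices, but it is not returned by `parameters()`, so an optimizer never updates it. The trainable values (which is where the paper's w_{i-j} would presumably live) sit in a separate `nn.Parameter` that the fixed index selects from.

```python
import torch
import torch.nn as nn


class RelativeBias(nn.Module):
    """Hypothetical module isolating the relative-position-bias lookup."""

    def __init__(self, ih: int, iw: int, heads: int):
        super().__init__()
        self.ih, self.iw, self.heads = ih, iw, heads

        # Assumed learnable table: one scalar per (relative offset, head).
        # There are (2*ih - 1) * (2*iw - 1) distinct (dy, dx) offsets.
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * ih - 1) * (2 * iw - 1), heads))

        # Constant lookup index, computed the same way as the issue's snippet.
        coords = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing="ij")
        coords = torch.flatten(torch.stack(coords), 1)           # (2, N)
        rel = coords[:, :, None] - coords[:, None, :]             # (2, N, N)
        rel[0] += ih - 1
        rel[1] += iw - 1
        rel[0] *= 2 * iw - 1
        relative_index = rel.permute(1, 2, 0).sum(-1).flatten().unsqueeze(1)  # (N*N, 1)
        # Buffer: saved in state_dict, never seen by the optimizer.
        self.register_buffer("relative_index", relative_index)

    def forward(self) -> torch.Tensor:
        # Select learnable biases with the fixed integer index -> (heads, N, N).
        n = self.ih * self.iw
        bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))       # (N*N, heads)
        return bias.view(n, n, self.heads).permute(2, 0, 1)


m = RelativeBias(ih=2, iw=2, heads=3)
print([name for name, _ in m.named_parameters()])  # ['relative_bias_table']
print([name for name, _ in m.named_buffers()])     # ['relative_index']
print(m.relative_index.view(4, 4).tolist())
# [[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]]
```

For a 2 x 2 grid the index matrix has 4 on the diagonal (zero offset maps to the center of the 3 x 3 table), and any two query/key pairs with the same (dy, dx) share an index, which is what makes the bias translation-invariant. Gradients flow into `relative_bias_table` through `gather`, while `relative_index` stays a fixed integer tensor.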