ma-xu / FCViT

A Close Look at Spatial Modeling: From Attention to Convolution
Apache License 2.0

Could you provide alternative code not using einops rearrange? #2

Closed: nikbobo closed this issue 1 year ago

nikbobo commented 1 year ago
    def forward(self,x):
        if self.weight_gc:
            b,c,w,h = x.size()
            x = rearrange(x,"b c x y -> b c (x y)")
            gap = x.mean(dim=-1, keepdim=True)
            # q, g = map(lambda t: rearrange(t, 'b (h d) n -> b h d n', h = self.head), [x,gap])  #[b,head, hdim, n]
            q = rearrange(x, 'b (h d) n -> b h d n', h = self.head)
            g = rearrange(gap, 'b (h d) n -> b h d n', h = self.head)
            sim = einsum('bhdi,bhjd->bhij', q, g.transpose(-1, -2)).squeeze(dim=-1) * self.scale  #[b,head, w*h]
            std, mean = torch.std_mean(sim, dim=[1,2], keepdim=True)
            sim = (sim-mean)/(std+self.epsilon)
            sim = sim * self.rescale_weight.unsqueeze(dim=0).unsqueeze(dim=-1) + self.rescale_bias.unsqueeze(dim=0).unsqueeze(dim=-1)
            sim = sim.reshape(b,self.head,1, w, h) # [b, head, 1, w, h]
            gc = self._get_gc(gap.squeeze(dim=-1)).reshape(b,self.head,-1).unsqueeze(dim=-1).unsqueeze(dim=-1)  # [b, head, hdim, 1, 1]
            gc = rearrange(sim*gc, "b h d x y -> b (h d) x y")  # [b, head, hdim, w, h] -> [b, c, w, h]

I want to convert the model to TorchScript, but PyTorch reports that rearrange is not supported in TorchScript. Could you provide alternative code that does not use einops rearrange? I tried to fix it myself, but my coding ability is limited.
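A minimal sketch, not the authors' fix: every rearrange in the snippet can be replaced by a plain reshape, because einops groups such as (h d) are row-major with h as the outer factor, which is exactly how torch reshape splits and merges dimensions. The constructor is hypothetical, since the issue does not show __init__: the head count, scale, epsilon, per-head rescale_weight/rescale_bias shapes, and gc_fc (a stand-in for self._get_gc) are all assumptions. The sketch also swaps the einsum for an equivalent torch.matmul and std_mean for separate mean/std calls, and it implements only the weight_gc branch.

```python
import torch
import torch.nn as nn


class SpatialGC(nn.Module):
    """Rearrange-free sketch of the weight_gc branch posted in the issue.

    Hypothetical: __init__ is guessed because the issue does not show it.
    Only the reshape logic mirrors the posted forward().
    """

    def __init__(self, dim: int, head: int = 4, epsilon: float = 1e-5):
        super().__init__()
        assert dim % head == 0, "channels must split evenly across heads"
        self.head = head
        self.scale = (dim // head) ** -0.5  # assumed scaling factor
        self.epsilon = epsilon
        # Assumed: one weight/bias per head, broadcast to [1, head, 1].
        self.rescale_weight = nn.Parameter(torch.ones(head))
        self.rescale_bias = nn.Parameter(torch.zeros(head))
        # Hypothetical stand-in for self._get_gc: maps [b, c] -> [b, c].
        self.gc_fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.size(0)
        c = x.size(1)
        w = x.size(2)
        h = x.size(3)
        d = c // self.head  # per-head channel count (hdim)
        n = w * h

        # "b c x y -> b c (x y)": merge the two spatial dims.
        x = x.reshape(b, c, n)
        gap = x.mean(dim=-1, keepdim=True)  # [b, c, 1]

        # "b (h d) n -> b h d n": split channels, head is the outer factor.
        q = x.reshape(b, self.head, d, n)  # [b, head, d, n]
        g = gap.reshape(b, self.head, d, 1)  # [b, head, d, 1]

        # Same contraction as the einsum: sum over d, giving [b, head, n].
        sim = torch.matmul(g.transpose(-1, -2), q).squeeze(2) * self.scale

        # Normalize over heads and positions, as torch.std_mean did.
        mean = sim.mean(dim=[1, 2], keepdim=True)
        std = sim.std(dim=[1, 2], keepdim=True)
        sim = (sim - mean) / (std + self.epsilon)
        sim = sim * self.rescale_weight.view(1, -1, 1) + self.rescale_bias.view(1, -1, 1)
        sim = sim.reshape(b, self.head, 1, w, h)  # [b, head, 1, w, h]

        gc = self.gc_fc(gap.squeeze(-1)).reshape(b, self.head, d, 1, 1)

        # "b h d x y -> b (h d) x y": merge head and d back into channels.
        return (sim * gc).reshape(b, c, w, h)
```

Usage: torch.jit.script(SpatialGC(dim=32, head=4)) should compile, since the forward uses only reshape, matmul, and reductions. As with the einops version, the channel count must be divisible by the number of heads.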