```python
def forward(self, x):
    if self.weight_gc:
        b, c, w, h = x.size()
        x = rearrange(x, "b c x y -> b c (x y)")
        gap = x.mean(dim=-1, keepdim=True)
        # q, g = map(lambda t: rearrange(t, 'b (h d) n -> b h d n', h=self.head), [x, gap])  # [b, head, hdim, n]
        q = rearrange(x, 'b (h d) n -> b h d n', h=self.head)
        g = rearrange(gap, 'b (h d) n -> b h d n', h=self.head)
        sim = einsum('bhdi,bhjd->bhij', q, g.transpose(-1, -2)).squeeze(dim=-1) * self.scale  # [b, head, w*h]
        std, mean = torch.std_mean(sim, dim=[1, 2], keepdim=True)
        sim = (sim - mean) / (std + self.epsilon)
        sim = sim * self.rescale_weight.unsqueeze(dim=0).unsqueeze(dim=-1) + self.rescale_bias.unsqueeze(dim=0).unsqueeze(dim=-1)
        sim = sim.reshape(b, self.head, 1, w, h)  # [b, head, 1, w, h]
        gc = self._get_gc(gap.squeeze(dim=-1)).reshape(b, self.head, -1).unsqueeze(dim=-1).unsqueeze(dim=-1)  # [b, head, hdim, 1, 1]
        gc = rearrange(sim * gc, "b h d x y -> b (h d) x y")  # [b, head, hdim, w, h] -> [b, c, w, h]
```
I want to convert this model to TorchScript, but PyTorch reports that `rearrange` is not supported in TorchScript.
Could you provide alternative code that does not use einops `rearrange`?
I tried to fix it myself, but my coding ability is limited.
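For reference, here is a minimal sketch of the same `forward` body with the `rearrange` calls replaced by native `view`/`reshape` operations, assuming the attributes used in the snippet above (`self.weight_gc`, `self.head`, `self.scale`, `self.epsilon`, `self.rescale_weight`, `self.rescale_bias`, `self._get_gc`). The `einsum` is rewritten as an equivalent `torch.matmul` so nothing depends on string equations or einops; the rest of the method (not shown in the snippet) is assumed unchanged.

```python
import torch

def forward(self, x):
    if self.weight_gc:
        b, c, w, h = x.size()
        hdim = c // self.head

        # "b c x y -> b c (x y)": flatten the spatial dims
        x = x.reshape(b, c, w * h)                        # [b, c, w*h]
        gap = x.mean(dim=-1, keepdim=True)                # [b, c, 1]

        # "b (h d) n -> b h d n": split channels into (head, hdim); a pure reshape
        q = x.view(b, self.head, hdim, w * h)             # [b, head, hdim, w*h]
        g = gap.view(b, self.head, hdim, 1)               # [b, head, hdim, 1]

        # equivalent to einsum('bhdi,bhjd->bhij', q, g.transpose(-1, -2)).squeeze(-1)
        sim = torch.matmul(q.transpose(-1, -2), g).squeeze(dim=-1) * self.scale  # [b, head, w*h]

        std, mean = torch.std_mean(sim, dim=[1, 2], keepdim=True)
        sim = (sim - mean) / (std + self.epsilon)
        sim = sim * self.rescale_weight.unsqueeze(dim=0).unsqueeze(dim=-1) \
              + self.rescale_bias.unsqueeze(dim=0).unsqueeze(dim=-1)
        sim = sim.reshape(b, self.head, 1, w, h)          # [b, head, 1, w, h]

        gc = self._get_gc(gap.squeeze(dim=-1)).reshape(b, self.head, -1).unsqueeze(dim=-1).unsqueeze(dim=-1)  # [b, head, hdim, 1, 1]

        # "b h d x y -> b (h d) x y": merge (head, hdim) back into channels
        gc = (sim * gc).reshape(b, c, w, h)               # [b, c, w, h]
        # ... rest of the method unchanged
```

The channel split `b (h d) n -> b h d n` with `h=self.head` puts the head index on the outer factor of the channel dimension, so a plain `view` reproduces it without any permute; likewise the final merge is a contiguous reshape back to `[b, c, w, h]`. With the einops calls gone, the module should be scriptable with `torch.jit.script`, provided the parts of `forward` not shown here are also TorchScript-compatible.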