Closed by egistific 4 years ago
Hi, using some other input sizes raises an error, so I moved the following two lines of the class AAConv2d into the forward section, after line 70:
self.key_rel_h = nn.Parameter(dk**-0.5 + torch.randn(dk//nh, 2*H-1))
self.key_rel_w = nn.Parameter(dk**-0.5 + torch.randn(dk//nh, 2*W-1))
May I know if this is correct? Thanks!
Hi, the code actually works fine. I think I made a mistake somewhere.
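For anyone who lands here with the same question, a small sketch of why the two tables are tied to the input size may help. It assumes the usual relative-position scheme, where a table needs one column per distinct offset j - i between two positions along an axis. The helper names and the sizes H = 7, W = 5 are made up for illustration and are not taken from the repository.

```python
def relative_offsets(length):
    """Distinct offsets (j - i) between positions 0..length-1 along one axis."""
    return sorted({j - i for i in range(length) for j in range(length)})


def rel_table_width(length):
    """Columns a relative-position table needs for an axis of this length."""
    return 2 * length - 1


# Hypothetical feature-map height and width, chosen only for this example.
H, W = 7, 5

# Offsets run from -(length - 1) to +(length - 1), so there are 2*length - 1 of them.
assert len(relative_offsets(H)) == rel_table_width(H)
assert len(relative_offsets(W)) == rel_table_width(W)

print(relative_offsets(3))
print(rel_table_width(H), rel_table_width(W))
```

Since the widths 2*H-1 and 2*W-1 depend on the feature-map size, tables built for one input size will generally not fit another. As for moving the two lines into forward: an nn.Parameter created there is drawn from torch.randn again on every call, and an optimizer constructed beforehand will not update it, so the relative embeddings would most likely never be learned. Keeping them in __init__ with the intended H and W looks like the safer choice.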