Closed by cyiheng 3 years ago
Hi! Thank you for your interest. First of all, here is a simple implementation of linear FIN that you can use as a reference.
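import torch
import torch.nn as nn

class LinearFIN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Plain instance norm; the affine part is computed from phi below.
        self.instance_norm = nn.InstanceNorm2d(dim, affine=False)
        # gamma = a_gamma * phi + b_gamma (starts at 1)
        self.a_gamma = nn.Parameter(torch.zeros(dim))
        self.b_gamma = nn.Parameter(torch.ones(dim))
        # beta = a_beta * phi + b_beta (starts at 0)
        self.a_beta = nn.Parameter(torch.zeros(dim))
        self.b_beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x, phi):
        # x: (B, dim, H, W); phi is expected to be (B,), one scalar per sample.
        gamma = self.a_gamma.unsqueeze(0) * phi.unsqueeze(-1) + self.b_gamma.unsqueeze(0)  # (B, dim)
        beta = self.a_beta.unsqueeze(0) * phi.unsqueeze(-1) + self.b_beta.unsqueeze(0)  # (B, dim)
        # Broadcast the per-channel scale and shift over the spatial dimensions.
        return self.instance_norm(x) * gamma.unsqueeze(-1).unsqueeze(-1) + beta.unsqueeze(-1).unsqueeze(-1)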
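The broadcasting in forward suggests that phi is a one-dimensional tensor holding one scalar per sample. As a quick sanity check, here is a sketch that repeats the same arithmetic with plain tensors; the batch and feature sizes are arbitrary values picked for illustration, not values from the repository:

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 3, 4, 4          # arbitrary sizes, for illustration only
x = torch.randn(B, C, H, W)      # feature map
phi = torch.tensor([0.0, 1.0])   # assumed: one scalar per sample, shape (B,)

# Same initial values as the parameters of LinearFIN.
a_gamma, b_gamma = torch.zeros(C), torch.ones(C)
a_beta, b_beta = torch.zeros(C), torch.zeros(C)

# Linear maps from phi to per-channel scale and shift, each of shape (B, C).
gamma = a_gamma.unsqueeze(0) * phi.unsqueeze(-1) + b_gamma.unsqueeze(0)
beta = a_beta.unsqueeze(0) * phi.unsqueeze(-1) + b_beta.unsqueeze(0)

normed = nn.InstanceNorm2d(C, affine=False)(x)
# Add two trailing axes so the (B, C) tensors broadcast over H and W.
out = normed * gamma[:, :, None, None] + beta[:, :, None, None]

print(gamma.shape, out.shape)
```

With this initialisation gamma is 1 and beta is 0 for every phi, so the layer starts out as plain instance normalisation and only becomes dependent on phi once a_gamma and a_beta are trained away from zero.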
As for the modifications to the rest of the code, you should be mostly fine if you just replace cos_phi and sin_phi with a single phi. The only point where you should be more cautious is the phiNet_A angle estimation, which should be replaced by a point estimation on a linear manifold. In other words, avoid extracting the sine and cosine of the estimated angle and use the predicted value directly.
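To make the difference concrete, here is a rough sketch of the two conditioning styles. PhiHead, cyclic_condition and linear_condition are hypothetical names invented for this example, and the one-layer head is only a stand-in, not the actual phiNet_A from the repository:

```python
import torch
import torch.nn as nn

class PhiHead(nn.Module):
    """Hypothetical stand-in for an estimation network (not the real phiNet_A)."""
    def __init__(self, in_features):
        super().__init__()
        self.fc = nn.Linear(in_features, 1)

    def forward(self, feats):
        # One scalar estimate per sample, shape (B,).
        return self.fc(feats).squeeze(-1)

def cyclic_condition(angle):
    # Cyclic variant: the estimate is treated as an angle and
    # passed on as its cosine and sine.
    return torch.cos(angle), torch.sin(angle)

def linear_condition(phi):
    # Linear variant: the predicted value is used directly.
    return phi

feats = torch.randn(4, 16)       # dummy features, sizes chosen arbitrarily
pred = PhiHead(16)(feats)

cos_phi, sin_phi = cyclic_condition(pred)   # what the cyclic code would consume
phi = linear_condition(pred)                # what LinearFIN.forward would consume
print(phi.shape)
```

In the linear case the prediction is no longer wrapped onto a circle, so whatever range phi should live in has to come from the estimator and its training targets; how to bound or normalise it is left open here.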
Hi!
First of all, thank you for your work. I have been waiting for your code ever since I read your paper!
I was wondering if you could give some advice on how to modify your code to switch the FIN layer from a cyclic function to a linear one?
So far I have only tried replacing every cos_phi / sin_phi with a single phi, but I'm not sure that will be enough. I may be missing some major points by changing only these.
Thank you again!