Open jamie212 opened 1 month ago
Your work is excellent! I read your paper and noticed that you mentioned using SGA. However, I'm having some trouble understanding parts of the code. Could you tell me which parts of your model_our are the SGA modules?
Thank you for your interest in my work. The following two classes implement the SGA module.
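
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class Gconv(nn.Module):
        def __init__(self, in_ch, out_ch):
            super(Gconv, self).__init__()
            self.src_fc = nn.Linear(in_ch, out_ch)
            self.msg_fc = nn.Linear(in_ch, out_ch)
            self.bn = nn.BatchNorm1d(out_ch, affine=True, track_running_stats=True)

        def forward(self, A, source, message):
            # A: (b, n, m) graph; source: (b, n, c); message: (b, m, c)
            src = self.src_fc(source)
            msg = self.msg_fc(message)
            # Aggregate messages through the graph, then add the transformed source.
            gen = torch.bmm(A, F.leaky_relu(msg, negative_slope=0.2)) + F.leaky_relu(src, negative_slope=0.2)
            # BatchNorm1d normalizes over dim 1, so move channels there and back.
            return self.bn(gen.permute(0, 2, 1)).permute(0, 2, 1)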

    class GNN(nn.Module):
        def __init__(self, channel):
            super(GNN, self).__init__()
            self.channel = channel
            self.gcn_cross = Gconv(in_ch=channel, out_ch=channel)
            self.gcn_self = Gconv(in_ch=channel, out_ch=channel)

        @staticmethod
        def build_graph(src, tgt):
            """
            src -> (b,wh,c)
            tgt -> (b,wh,c)
            """
            with torch.no_grad():
                graph = src.bmm(tgt.permute(0, 2, 1))
                graph = F.softmax(graph, dim=-1)
                graph = F.normalize(graph, p=1, dim=-2)
            return graph

        def forward(self, skt, ref):
            b, c, h, w = skt.size()
            # Flatten spatial dims: (b, c, h, w) -> (b, h*w, c)
            skt = skt.view(b, c, h * w).permute(0, 2, 1)
            ref = ref.view(b, c, h * w).permute(0, 2, 1)
            # Cross graph: sketch tokens gather messages from reference tokens.
            sr = self.build_graph(src=skt, tgt=ref)
            gen = self.gcn_cross(A=sr, source=skt, message=ref) + skt
            # Self graph: the fused tokens gather messages from each other.
            gg = self.build_graph(src=gen, tgt=gen)
            ggen = self.gcn_self(A=gg, source=gen, message=gen) + gen
            # Output stays in (b, h*w, c) layout.
            return ggen
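
For anyone else reading along, here is a minimal standalone sketch of what `build_graph` computes and how the resulting graph is used to aggregate features. It only repeats the three tensor operations from the code above. The function name `build_affinity` and the tensor sizes are hypothetical and chosen for illustration; this is not taken from the repository.

```python
import torch
import torch.nn.functional as F


def build_affinity(src, tgt):
    """Illustrative copy of the graph construction in GNN.build_graph.

    src: (b, n, c), tgt: (b, m, c) -> graph: (b, n, m)
    """
    with torch.no_grad():
        graph = src.bmm(tgt.permute(0, 2, 1))    # pairwise dot products
        graph = F.softmax(graph, dim=-1)         # each row now sums to 1
        graph = F.normalize(graph, p=1, dim=-2)  # rescale so each column sums to 1
    return graph


torch.manual_seed(0)
b, c, h, w = 2, 8, 3, 4  # hypothetical sizes, for illustration only

# Same flattening as GNN.forward: (b, c, h*w) -> (b, h*w, c)
skt = torch.randn(b, c, h * w).permute(0, 2, 1)
ref = torch.randn(b, c, h * w).permute(0, 2, 1)

graph = build_affinity(skt, ref)      # (b, h*w, h*w)
aggregated = torch.bmm(graph, ref)    # (b, h*w, c): reference features per sketch token
column_sums = graph.sum(dim=-2)       # (b, h*w)

print(graph.shape, aggregated.shape)
```

Because the L1 normalization over `dim=-2` runs after the softmax, the columns of the final graph sum to 1 while the rows generally no longer do. The graph is also built under `torch.no_grad()`, so gradients reach the features only through the `Gconv` layers and the residual connections, not through the graph itself.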