DanielCho-HK / ICME2023-Attention-Aware-Anime-Line-Drawing-Colorization

6 stars 2 forks source link

About SGA #3

Open jamie212 opened 1 month ago

jamie212 commented 1 month ago

Your work is excellent! I read your paper and noticed that you mentioned using SGA. However, I'm having some trouble understanding parts of the code. Could you tell me which parts of your `model_our` are the SGA modules?

DanielCho-HK commented 1 month ago

> Your work is excellent! I read your paper and noticed that you mentioned using SGA. However, I'm having some trouble understanding parts of the code. Could you tell me which parts of your `model_our` are the SGA modules?

Thank you for your interest in my work. The following two classes, `Gconv` and `GNN`, implement the SGA module.

class Gconv(nn.Module):
    """Graph-convolution layer used by the SGA module.

    Every output node combines its own projected feature with an
    ``A``-weighted aggregation of projected message features::

        gen = A @ leaky_relu(msg_fc(message)) + leaky_relu(src_fc(source))

    and the result is batch-normalized over the channel dimension.

    Args:
        in_ch: Channel count of ``source`` and ``message``.
        out_ch: Channel count of the output.
    """

    def __init__(self, in_ch, out_ch):
        # Must be the ``__init__`` dunder: with a plain ``init`` method the
        # layers below are never created and construction with
        # ``in_ch=`` / ``out_ch=`` keywords fails.
        super().__init__()
        self.src_fc = nn.Linear(in_ch, out_ch)
        self.msg_fc = nn.Linear(in_ch, out_ch)
        self.bn = nn.BatchNorm1d(out_ch, affine=True, track_running_stats=True)

    def forward(self, A, source, message):
        """Aggregate ``message`` into ``source`` along the weighted graph ``A``.

        Args:
            A: Batched adjacency / weight matrix of shape (b, n, m).
            source: Features of the n receiving nodes, shape (b, n, in_ch).
            message: Features of the m sending nodes, shape (b, m, in_ch).

        Returns:
            Tensor of shape (b, n, out_ch).
        """
        src = self.src_fc(source)
        msg = self.msg_fc(message)

        gen = torch.bmm(A, F.leaky_relu(msg, negative_slope=0.2)) + F.leaky_relu(
            src, negative_slope=0.2
        )

        # BatchNorm1d normalizes dim 1 of an (N, C, L) tensor, so move the
        # channel axis there for the norm and then back to (b, n, out_ch).
        return self.bn(gen.permute(0, 2, 1)).permute(0, 2, 1)

class GNN(nn.Module):
    """SGA block: a cross graph from sketch to reference, then a self graph.

    ``forward`` first aggregates reference features into the sketch nodes
    through ``gcn_cross``. It then refines the result with ``gcn_self`` over
    a graph built from the intermediate features. Both steps are residual.

    Args:
        channel: Channel count ``c`` shared by the sketch and reference maps.
    """

    def __init__(self, channel):
        # Must be the ``__init__`` dunder (a plain ``init`` never runs).
        super().__init__()
        self.channel = channel
        self.gcn_cross = Gconv(in_ch=channel, out_ch=channel)
        self.gcn_self = Gconv(in_ch=channel, out_ch=channel)

    @staticmethod
    def build_graph(src, tgt):
        """Build a dense affinity graph from ``src`` nodes to ``tgt`` nodes.

        src -> (b, n, c)
        tgt -> (b, m, c)
        returns (b, n, m)

        The steps are:
        1. Dot-product affinities.
        2. Softmax over the ``tgt`` nodes (last dim).
        3. L1 normalization over the ``src`` nodes (dim -2), so each column
           sums to 1.

        The graph is computed under ``torch.no_grad()``, so no gradient flows
        through it. Gradients still reach the features through the ``Gconv``
        projections. (Presumably this is the intended stop-gradient behaviour;
        confirm against the paper.)
        """
        with torch.no_grad():
            graph = src.bmm(tgt.permute(0, 2, 1))
            graph = F.softmax(graph, dim=-1)
            graph = F.normalize(graph, p=1, dim=-2)

        return graph

    def forward(self, skt, ref):
        """Run cross-graph then self-graph aggregation.

        Args:
            skt: Sketch feature map of shape (b, c, h, w).
            ref: Reference feature map of shape (b, c, h_r, w_r). Its spatial
                size need not equal ``skt``'s.

        Returns:
            Node features of shape (b, h*w, c). They are NOT reshaped back
            into a (b, c, h, w) map; callers must do that themselves.
        """
        b, c, h, w = skt.size()
        # ``reshape`` (rather than ``view``) also accepts non-contiguous
        # inputs. ``ref`` is flattened with its own spatial size.
        skt = skt.reshape(b, c, h * w).permute(0, 2, 1)
        ref = ref.reshape(b, c, -1).permute(0, 2, 1)

        sr = self.build_graph(src=skt, tgt=ref)  # (b, h*w, h_r*w_r)
        gen = self.gcn_cross(A=sr, source=skt, message=ref) + skt

        gg = self.build_graph(src=gen, tgt=gen)  # (b, h*w, h*w)
        ggen = self.gcn_self(A=gg, source=gen, message=gen) + gen

        return ggen