THUDM / GraphMAE

GraphMAE: Self-Supervised Masked Graph Autoencoders in KDD'22
478 stars 75 forks source link

self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim)) #34

Closed songshuhan closed 1 year ago

songshuhan commented 1 year ago

你好,我在看源码的时候看到 self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim)) .... out_x[token_nodes] += self.enc_mask_token

想咨询一下为什么所有的token_nodes在重建的时候全是共享的呢?这里self.enc_mask_token是不是应该是nn.Parameter(torch.zeros(num_token_nodes, in_dim))??

songshuhan commented 1 year ago

想咨询一下共享向量的作用是什么?

songshuhan commented 1 year ago

是不是应该给每一个token_node都有一个可学习的向量?

songshuhan commented 1 year ago

补充一个小问题,如果不加这个nn.Parameter(torch.zeros(1, in_dim)),对结果会有什么影响么?因为感觉在decoder之前已经mask掉了,就是有点不太清晰out_x[token_nodes] += self.enc_mask_token的意义,希望可以答疑,感谢

THINK2TRY commented 1 year ago

@songshuhan mask_token [MASK] 的作用是指示对应的节点是被 mask 掉的,需要模型在训练的时候预测这个点的特征,所以不需要每个点有一个独特的 learnable vector

songshuhan commented 1 year ago

感谢作者的回复!我git clone了一下代码运行了一下,原始的代码在cora节点分类上的准确率在82.5%左右,然后我不加self.enc_mask_token,在cora上的准确率是81.5%左右,看起来加上确实会好一点。

不过每个节点都是根据encoder从邻居节点聚合信息来更新自己的embedding,而且我们也知道token_nodes的索引,好像并不需要一个“指示”向量加入才可以训练更新到被mask的节点,所以我感觉还是没太理解这个意义是什么?

THINK2TRY commented 1 year ago

因为 SSL 的 objective 是 重建 masked node features, mask token 是用来指示 这个节点的特征是 被 mask 了。类似的做法在NLP 和 CV 中都存在。作为一个 learnable 的 向量可能也能缓解本身 mask 节点特征后整个图的特征分布、帮助模型更好的学习重建等。

songshuhan commented 1 year ago

感谢作者的回复,受益匪浅