Closed by jasperzhong 2 years ago
This looks like #278 plus a memory module: each node additionally keeps a memory vector that remembers the messages it has received, giving the model a long-term memory.
First, there are two kinds of events: node-wise events and interaction (edge) events.
The memory module consists of a message function, a message aggregator, and a memory updater.
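A minimal sketch of such a memory module, assuming a GRU-based memory updater and a "most recent" message aggregator; all dimensions and names here are illustrative, not taken from the repo:

```python
import torch
import torch.nn as nn

# Illustrative sizes; not taken from the paper or the repo.
NUM_NODES, MEM_DIM, MSG_DIM = 10, 4, 6

class MemoryModule(nn.Module):
    """Per-node memory vector, updated from incoming messages."""
    def __init__(self, num_nodes, mem_dim, msg_dim):
        super().__init__()
        # Memory is state, not a learned parameter, so store it as a buffer.
        self.register_buffer("memory", torch.zeros(num_nodes, mem_dim))
        # Memory updater: a recurrent cell (the paper uses GRU/LSTM variants).
        self.updater = nn.GRUCell(msg_dim, mem_dim)

    def update(self, node_ids, messages):
        # "Most recent" aggregation: keep only the last message per node.
        last = {}
        for n, m in zip(node_ids.tolist(), messages):
            last[n] = m
        nodes = torch.tensor(list(last.keys()))
        msgs = torch.stack(list(last.values()))
        # Detached here for simplicity; the training flow described later is
        # what actually lets gradients reach the updater's parameters.
        self.memory[nodes] = self.updater(msgs, self.memory[nodes]).detach()

mem = MemoryModule(NUM_NODES, MEM_DIM, MSG_DIM)
mem.update(torch.tensor([1, 3, 1]), torch.rand(3, MSG_DIM))
```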
The rest feels much like #278: embeddings are again computed from neighboring nodes, except that here the memory is taken into account, i.e., the input is node feature + memory. Everything else is the same, including the functional time encoding from #278.
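A sketch of the "node feature + memory" idea together with the functional time encoding. The module names and dimensions are my own hypothetical choices, the feature/memory combination is shown as concatenation, and a single linear layer stands in for TGN's actual temporal graph attention over neighbors:

```python
import torch
import torch.nn as nn

class TimeEncoder(nn.Module):
    """Functional time encoding: cos(w * dt + b), as in TGAT/TGN."""
    def __init__(self, dim):
        super().__init__()
        self.w = nn.Linear(1, dim)

    def forward(self, dt):  # dt: (N,) time deltas to the current time
        return torch.cos(self.w(dt.unsqueeze(-1)))

# Hypothetical dims; the real model attends over sampled neighbors.
FEAT, MEM, TIME = 5, 4, 3
time_enc = TimeEncoder(TIME)
proj = nn.Linear(FEAT + MEM + TIME, 8)

def embed(node_feat, memory, dt):
    # Combine node feature + memory (concatenated here; an elementwise
    # sum is another reading of "+") with the time encoding.
    z = torch.cat([node_feat, memory, time_enc(dt)], dim=-1)
    return torch.relu(proj(z))

h = embed(torch.rand(2, FEAT), torch.rand(2, MEM), torch.tensor([0.5, 1.0]))
print(h.shape)  # torch.Size([2, 8])
```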
A win!
During actual training, the input to each iteration is a batch of interactions. The flow above has one problem: how do the parameters of the memory module get updated? Before computing on the current batch of interactions, they first update the memory with the previous batch of interactions (kept in the raw message store), and then use the updated memory to compute node embeddings for the current batch. This way, gradients can backpropagate into the memory module and update its parameters. The reason for using the previous batch of interactions is to avoid, at prediction time (e.g. link prediction), having already seen the very edges to be predicted.
Note that when computing on the current batch of interactions, the memory used is based entirely on the previous batch of interactions; in other words, later interactions within a batch do not make use of earlier ones in the same batch — all of them build on the previous batch. This is a trade-off between speed and update granularity. In practice, they use a batch size of 200.
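The delayed-update flow above can be sketched as a training step. The model here is a deliberately tiny stand-in (a GRU memory updater plus a linear link scorer with made-up dimensions, positives only), not the repo's API:

```python
import torch
import torch.nn as nn

# Illustrative sizes and names; not the repo's actual interface.
NUM_NODES, MEM, MSG = 6, 4, 4

memory = torch.zeros(NUM_NODES, MEM)            # node memory (state)
updater = nn.GRUCell(MSG, MEM)                  # memory updater
scorer = nn.Linear(2 * MEM, 1)                  # link predictor
opt = torch.optim.Adam(list(updater.parameters()) + list(scorer.parameters()))
raw_message_store = []                          # messages from the previous batch

def train_step(batch_src, batch_dst, batch_msgs):
    global memory, raw_message_store
    opt.zero_grad()
    # 1. Apply the *previous* batch's raw messages to memory first.
    mem = memory
    for nodes, msgs in raw_message_store:
        upd = updater(msgs, mem[nodes])
        mem = mem.clone()
        mem[nodes] = upd                        # differentiable update
    # 2. Predict the current batch's links from the updated memory.
    logits = scorer(torch.cat([mem[batch_src], mem[batch_dst]], dim=-1))
    loss = nn.functional.binary_cross_entropy_with_logits(
        logits, torch.ones_like(logits))        # positives only, for brevity
    # 3. Gradients reach the memory updater through step 1.
    loss.backward()
    opt.step()
    memory = mem.detach()                       # persist state, drop the graph
    # 4. Stash the current batch's raw messages for the next iteration.
    raw_message_store = [(torch.cat([batch_src, batch_dst]),
                          torch.cat([batch_msgs, batch_msgs]))]
    return loss.item()

l1 = train_step(torch.tensor([0, 1]), torch.tensor([2, 3]), torch.rand(2, MSG))
l2 = train_step(torch.tensor([1, 4]), torch.tensor([0, 5]), torch.rand(2, MSG))
```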
Input: a batch of events (src, dst, timestamp, edge_features, label)
Graph representation: adjacency list (for each node, the events it participates in, sorted by timestamp)
Neighbor sampling: uniform or most recent
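Given a time-sorted adjacency list, both sampling strategies reduce to a binary search for events strictly before the query time; the data layout below is an assumption for illustration, not the repo's:

```python
import bisect
import random

# Time-sorted adjacency list: node -> list of (timestamp, neighbor, event_idx).
adj = {
    0: [(1.0, 2, 0), (2.5, 3, 1), (4.0, 1, 2), (5.5, 4, 3)],
}

def sample_neighbors(node, t, k, strategy="recent"):
    """Sample up to k neighbors of `node` from events strictly before time t."""
    events = adj.get(node, [])
    # Binary search for the cutoff: events with timestamp < t.
    hi = bisect.bisect_left(events, (t,))
    valid = events[:hi]
    if strategy == "recent":
        return valid[-k:]                              # k most recent events
    return random.sample(valid, min(k, len(valid)))    # uniform

print(sample_neighbors(0, 5.0, 2))  # [(2.5, 3, 1), (4.0, 1, 2)]
```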
Paper: https://arxiv.org/pdf/2006.10637v2
Code: https://github.com/twitter-research/tgn