def compute_edge_probabilities(self, src, dst, neg, edge_time, edge_idx, num_neigh=20):
def compute_temporal_embeddings(self, src, dst, neg, edge_time, edge_idx, num_neigh=20):
"""
1. get_updated_mem (updated mem & last updated ts)
2. compute the time diffs between the mem's last-update time and the edge (event) time
3. compute embeddings using the embedding module
4. update the (pos nodes') memory based on the msgs and clear the msgs after the update
5. get the current msgs and save them in the raw msg store for the next batch
6. return the embeddings of src, dst & neg
"""
def update_memory(self, nodes, messages):
"""
update the mem in place; no return value
"""
def get_updated_memory(self, nodes, messages):
"""
1. aggregate msgs of the same node using the message aggregator
2. compute msgs using msg_func
3. call the memory updater to get the updated memory (updated mem & last updated ts)
"""
"""
properties:
1. n_nodes
2. memory_dim
3. input_dim
4. msg_dim
5. device
6. memory
7. last_update (timestamp of each node's last update)
8. msg (defaultdict)
The memory is initialized to zero before each epoch.
function:
1. init
2. store_raw_msg (directly extends the msg lists)
3. get & set mem
4. get last update
5. detach mem
6. clear msg
"""
class Memory(nn.Module):
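A compact sketch of such a module (interface names are my guesses at the shape of the real class):

from collections import defaultdict
import torch
import torch.nn as nn

class MemorySketch(nn.Module):
    def __init__(self, n_nodes, memory_dim, device="cpu"):
        super().__init__()
        self.n_nodes, self.memory_dim, self.device = n_nodes, memory_dim, device
        self.init_memory()

    def init_memory(self):
        # called before every epoch: zero memory, zero last_update, empty msg store
        self.memory = nn.Parameter(torch.zeros(self.n_nodes, self.memory_dim,
                                               device=self.device), requires_grad=False)
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes, device=self.device),
                                        requires_grad=False)
        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_messages):
        for node in nodes:                      # directly extend the per-node lists
            self.messages[node].extend(node_messages[node])

    def get_memory(self, node_idxs):
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def detach_memory(self):
        self.memory.detach_()                   # cut the autograd graph between batches

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []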
def aggregate(self, node_ids, messages):
"""
1. unique node_ids
2. aggregate msgs based on the unique node_ids
"""
"""
GRU & RNN
1. input_size = msg_dim
2. hidden_size = mem_dim
"""
class SequenceMemoryUpdater(MemoryUpdater):
def update_memory(self, node_ids, msg, timestamps):
"""
set the result directly into the mem
"""
def get_updated_memory(self, node_ids, msg, timestamps):
"""
just call the GRU or RNN to get the update;
return the updated mem and the last update
"""
class GraphEmbedding(EmbeddingModule):
def compute_embedding(self, mem, src, ts, n_layers, num_neighb):
"""
recursive impl:
n_layers >= 0; when n_layers == 0, return source_node_features (the outermost hop of the neighborhood)
1. source_node_features = mem[source] + source_node_features
2. get neighbors of the current layer.
3. recursively call self.compute_embedding(n_layers-1) to get neighbor_embeddings (if layer == 0, return node_emb + time_emb)
4. encode time using delta time(time diff)
5. aggregate -> K, V and Q -> using attention model
"""
"""
graph attention embedding:
- a list of TemporalAttentionLayer
"""
class GraphAttentionEmbedding(GraphEmbedding):
def aggregate():
"""
get the attention model of the current layer and call its forward
"""
class TemporalAttentionLayer(torch.nn.Module):
"""
Temporal attention layer. Return the temporal embedding of a node given the node itself, its neighbors and the edge timestamps.
Properties:
1. n_head
2. feat_dim
3. time_dim
4. query_dim = feat_dim + time_dim
5. key_dim (==v_dim) = n_neighbors_features + time_dim + n_edge_features
6. merger
7. multi_head_attention
"""
def forward():
"""
it uses key_padding_mask because different nodes don't always have enough neighbors, so some need padding; the mask makes attention ignore the padded entries
QUERY: cat(node_features, time_feat)
KEY & VALUE: cat(node_features, edge_features, time_feat)
1. compute attention
2. source nodes with no neighbors get an all-zero attention output; the attention output is then added or concatenated to the original source node features and fed into an MLP, so an all-zero vector is never used directly
3. merger produces the final output:
1. h = linear(cat(attn_output, node_features))
2. output = linear(h)
"""
get_temporal_neighbor() is called from the Embedding Module; it returns neighbors, edge_idxs, and edge_times for the computation.
class NeighborFinder:
def __init__()
"""
Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
sort the list based on timestamp
"""
def search_before(self, src_idx, cut_time):
"""
RETURN: all the events happening before cut_time for user src_idx in the overall interaction graph. The returned events are sorted by time. 3 lists: neighbors, edge_idxs, timestamps
1. use np.searchsorted (binary search) to find the split index at cut_time
2. return the lists based on the index
"""
def get_temporal_neighbor(self, source_nodes, cut_times, n_neighbors=20):
"""
2 types of sample: most recent & uniform
If a src node has fewer than n_neighbors neighbors, the remaining slots are zero-padded (timestamp 0)
The return lists are sorted based on the timestamp (ascending timestamps)
"""
For negative sampling
class RandEdgeSampler(object):
def sample(self, size):
"""
return two randomly generated lists:
src_list & dst_list
The dst_list is then used as the negative batch.
"""
CTDG: Continuous-time dynamic graphs; a timed list of events
two types of events: node-wise events and interaction (edge) events
Core Modules
Memory Module
A newly added node's memory is initialized to a zero vector.
$s_i(t)$ is the memory state of node $i$ at time $t$.
The memory state is updated after each event involving the node.
Message Function
Msgs are used to update the memory state.
Function: for an interaction event between $i$ and $j$, $m_i(t) = \mathrm{msg}\left(s_i(t^-), s_j(t^-), \Delta t, e_{ij}(t)\right)$
Message Aggregator
Batch processing of events -> multiple events (msgs) at different times may target the same node
$\overline{m_i}(t) = agg(m_i(t_1), ..., m_i(t_b))$
$agg$ can be an RNN or attention, or as simple as the most recent msg or the mean of the msgs
Mem Updater
Feed $\overline{m_i}(t)$ (the aggregated message) as input and node $i$'s previous memory as the hidden state into an LSTM (or GRU); updating the hidden state yields the new memory.
$s_i(t) = mem(\overline{m_i}(t), s_i(t^-))$
Embeddings
The $h$ function has several implementations:
Temporal Graph Sum