PeterSH6 / paper-notes

My paper reading notes

ICLR '21 | Temporal Graph Networks For Deep Learning On Dynamic Graphs #20

Closed PeterSH6 closed 2 years ago

PeterSH6 commented 2 years ago

Training

(Screenshot of the training procedure, dated 2022-04-28)

PeterSH6 commented 2 years ago

TGN Source Code

Training Self Supervised

  1. Update memory first: update the memory of all nodes (only the negatives of the last batch) with the msgs stored in previous batches
  2. tgn.compute_edge_probabilities
  3. tgn.compute_temporal_embeddings
  4. emb_module.compute_embedding
  5. get the embeddings, update the memory (only for positive nodes, whose msgs are then deleted), and save the new raw msgs
  6. return the embeddings, compute the edge probabilities, compute the loss, and backpropagate (a sketch of this loop follows)
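
As a rough illustration, the training step could look like this. This is a minimal sketch, assuming a `tgn` model exposing the interface described below, a `rand_sampler` (RandEdgeSampler) for negatives, and an iterable `batches` of temporal edges; it is not the repo's exact code.

```python
import torch

criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=1e-4)

for src, dst, edge_time, edge_idx in batches:            # mini-batches of temporal edges
    _, neg = rand_sampler.sample(len(src))                # random negative destinations

    optimizer.zero_grad()
    # Internally: update memory from stored msgs, compute embeddings,
    # persist positive-node memory, and store this batch's raw msgs.
    pos_prob, neg_prob = tgn.compute_edge_probabilities(
        src, dst, neg, edge_time, edge_idx, num_neigh=20)

    labels = torch.cat([torch.ones_like(pos_prob), torch.zeros_like(neg_prob)])
    loss = criterion(torch.cat([pos_prob, neg_prob]).squeeze(), labels.squeeze())
    loss.backward()
    optimizer.step()

    tgn.memory.detach_memory()    # stop gradients from flowing across batches
```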

TGN

def compute_edge_probabilities(self, src, dst, neg, edge_time, edge_idx, num_neigh=20):

def compute_temporal_embeddings(self, src, dst, neg, edge_time, edge_idx, num_neigh=20):
    """
    1. get_updated_mem (updated mem & last updated ts)
    2. compute time diffs between time of the mem and edge(event) time
    3. compute embeddings using embedding module
    4. update the (pos nodes') memory based on the msgs and clear the msgs after the update
    5. get the current msgs and save them in the raw msg store for the next batch
    6. return the embeddings of src, dst & neg
    """

def update_memory(self, nodes, messages):
    """
    update the mem in place; no return value
    """
def get_updated_memory(self, nodes, messages):
    """
    1. aggregate the msgs of the same node using the message aggregator
    2. compute msgs using msg_func
    3. call the memory updater to get the updated memory (updated mem & last updated ts)
    """

Memory

"""
properties:
1. n_nodes
2. memory_dim
3. input_dim
4. msg_dim
5. device
6. memory
7. last_update (timestamp of the last update)
8. msg (defaultdict)

The memory is initialized to zero before each epoch.
function:
1. init
2. store_raw_msg (directly extends msg)
3. get & set mem
4. get last update
5. detach mem
6. clear msg
"""
class Memory(nn.Module):
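
A simplified, self-contained sketch of this class, reduced to the properties and functions listed above (not a drop-in replacement for the repo's memory.py):

```python
from collections import defaultdict

import torch
import torch.nn as nn

class Memory(nn.Module):
    """Simplified sketch of the node-memory store described above."""

    def __init__(self, n_nodes, memory_dim, device='cpu'):
        super().__init__()
        self.n_nodes = n_nodes
        self.memory_dim = memory_dim
        self.device = device
        self.__init_memory__()

    def __init_memory__(self):
        # called before every epoch: memory and last_update restart at zero
        self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dim), device=self.device),
                                   requires_grad=False)
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes, device=self.device),
                                        requires_grad=False)
        self.messages = defaultdict(list)   # node_id -> list of (raw_msg, timestamp)

    def store_raw_messages(self, nodes, node_id_to_messages):
        # directly extend each node's message list
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def detach_memory(self):
        # break the autograd graph so gradients do not flow across batches
        self.memory.detach_()
        for node, msgs in self.messages.items():
            self.messages[node] = [(m.detach(), t) for m, t in msgs]

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []
```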

Message Aggregator

def aggregate(self, node_ids, messages):
    """
    1. unique node_ids
    2. aggregate the msgs based on the unique node_ids
    """

Mem Updater

"""
GRU & RNN
1. input_size = msg_dim
2. hidden_size = mem_dim
"""
class SequenceMemoryUpdater(MemoryUpdater):
    def update_memory(self, node_ids, msg, timestamps):
    """
    directly sets the result into the memory
    """

    def get_updated_memory(self, node_ids, msg, timestamps):
    """
    just call the GRU or RNN to get the update
    returns the updated mem and the last update
    """

Embedding

class GraphEmbedding(EmbeddingModule):
    def compute_embedding(self, mem, src, ts, n_layers, num_neighb):
    """
    recursive impl:
    n_layers >= 0; when n_layers == 0, return source_node_features (the outermost layer)
    1. source_node_features = mem[source] + source_node_features
    2. get neighbors of the current layer.
    3. recursively call self.compute_embeddings(n_layers-1) and get neighbor_embeddings (if layer==0, return (node_emb + time_emb))
    4. encode time using delta time(time diff)
    5. aggregate -> K, V and Q -> using attention model
    """

"""
graph attention embedding:
- a list of TemporalAttentionLayer
"""

class GraphAttentionEmbedding(GraphEmbedding):
    def aggregate():
    """
    get the current layer and call forward of attention model
    """

class TemporalAttentionLayer(torch.nn.Module):

"""
Temporal attention layer. Return the temporal embedding of a node given the node itself, its neighbors and the edge timestamps.

Properties:
1. n_head
2. feat_dim
3. time_dim
4. query_dim = feat_dim + time_dim
5. key_dim (==v_dim) = n_neighbors_features + time_dim + n_edge_features
6. merger
7. multi_head_attention
"""

    def forward():
    """
    It uses key_padding_mask, probably because different nodes don't always have enough neighbors and some need padding; the mask makes the attention ignore the padded entries.

    QUERY: cat(node_features, time_feat)

    KEY & VALUE: cat(node_features, edge_features, time_feat)

    1. compute attention
    2. Source nodes with no neighbors have an all zero attention output. The attention output is then added or concatenated to the original source node features and then fed into an MLP. This means that an all zero vector is not used.
    3. merger and get the final output
        1. h = linear(cat(attn_output, node_features))
        2. output = linear(h)
    """

Sampling

Neighbor Finder

The embedding module calls get_temporal_neighbor(), which returns neighbors, edge_idxs, and edge_times for the computation.

class NeighborFinder:
    def __init__()
    """
    Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
    sort the list based on timestamp
    """

    def search_before(self, src_idx, cut_time):
    """
    RETURN: all the events happening before cut_time for user src_idx in the overall interaction graph. The returned events are sorted by time. 3 lists: neighbors, edge_idxs, timestamps
    1. use np.searchsorted(binary search) to find the index before cut time
    2. return the lists based on the index
    """

    def get_temporal_neighbor(self, source_nodes, cut_times, n_neighbors=20):
    """
    2 types of sample: most recent & uniform
    If a src_node has fewer than n_neighbors neighbors, the remaining entries (including their timestamps) are zero-padded
    The return lists are sorted based on the timestamp (ascending timestamps)
    """

RandEdgeSampler

For negative sampling

class RandEdgeSampler(object):
    def sample(self, size):
    """
    return two randomly generated lists:
    src_list & dst_list
    The dst_list is then used as the negative batch.
    """