Open Void-JackLee opened 1 year ago
According to the paper, the memory component needs to receive gradients, so the memory has to be updated before the embeddings are computed. However, the n-th batch cannot be used for that update because of the information leakage problem (the batch would help predict its own interactions), so the (n-1)-th batch is used instead. Only with memory_update_at_start=True does the memory component receive gradients and have its parameters updated.
> So here is the problem: `update_memory(positives, self.memory.messages)` updates the positive nodes of this batch, but the messages it uses come from the last batch. I don't understand why the code does this; maybe it's a bug?
the proper order is
Hi @shadow150519, after reading your reply I have a few questions to discuss. First, since the memory has to be updated in step 3 of your list anyway, why not update it directly in step 1 and take a copy at the same time? I guess one possible answer is that step 1 needs the updated memory of both positive and negative nodes, whereas step 3 only needs to update the memory of positive nodes. A possible optimization would be to merge step 3 into step 1: compute the memory of both positive and negative nodes, take a copy, and persist the update only for the positive nodes.
Thanks for your reply : )
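The merged step proposed in this comment could look roughly like the sketch below. It is a toy model in plain Python, not the repository's code: the function name `update_and_snapshot`, the list-based memory, the "last message" aggregation and the additive update are all assumptions made for illustration. It also ignores how gradients would flow through the persisted values, which a real implementation would have to consider.

```python
def update_and_snapshot(state, messages, query_nodes, persist_nodes):
    """Hypothetical merged step: compute the updated memory once, persist part of it.

    state         -- list with one float per node (toy memory)
    messages      -- dict: node id -> list of pending message values
    query_nodes   -- nodes whose updated memory is needed (positives and negatives)
    persist_nodes -- nodes whose update is written back (positives only)
    """
    query_set = set(query_nodes)
    if not set(persist_nodes) <= query_set:
        raise ValueError("persist_nodes must be a subset of query_nodes")

    # Snapshot used for the embedding; pending messages are applied to the copy.
    snapshot = list(state)
    for n in query_set:
        pending = messages.get(n, [])
        if pending:
            # Toy choices: keep only the last message, update by addition.
            snapshot[n] = state[n] + pending[-1]

    # Write back only the positive nodes and drop their consumed messages.
    for n in persist_nodes:
        state[n] = snapshot[n]
        messages.pop(n, None)

    return snapshot


def run_merged_demo():
    state = [0.0, 0.0, 0.0]
    messages = {0: [1.0], 2: [5.0]}
    # Node 0 is a positive node, node 2 is a negative sample.
    snapshot = update_and_snapshot(state, messages, [0, 2], [0])
    return snapshot, state, messages
```

In this toy run the snapshot contains the updated values for both node 0 and node 2, while only node 0 is written back and node 2 keeps its pending message.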
Hi @emalgorithm, I ran into a question while reading your code.

When `memory_update_at_start=True`, msg_agg and msg_func are computed twice: once before `compute_embedding` and once after it. Before `compute_embedding`, the `get_updated_memory` function computes the memory of all nodes. After `compute_embedding`, the `update_memory` function computes the memory of the positive nodes. The code comment here says "Persist the updates to the memory only for sources and destinations (since now we have new messages for them)", but the messages of this batch are actually stored after the memory update, so `update_memory` updates the memory from the messages of the last batch. So here is the problem: `update_memory(positives, self.memory.messages)` updates the positive nodes of this batch, but the messages it uses come from the last batch. I don't understand why the code does this; maybe it's a bug?

I think the code needs to either update the memory of all nodes (or record the last batch's positive nodes), or update the memory directly in the `get_updated_memory` function (replacing it with `update_memory`).
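The order of operations described above can be traced with a small toy model. This is only a sketch under stated assumptions, not the repository's implementation: the `toy_*` functions are hypothetical simplified stand-ins, the memory is one float per node, aggregation keeps the last message, the update is an addition, and messages of the positive nodes are assumed to be cleared and replaced after each batch.

```python
def apply_pending(state, messages, nodes):
    """Apply the last pending message of each node in `nodes` to `state` in place."""
    for n in nodes:
        pending = messages.get(n, [])
        if pending:
            state[n] += pending[-1]  # toy "last" aggregation, additive update


def toy_get_updated_memory(state, messages, nodes):
    """Return an updated copy of the memory; the stored memory is left unchanged."""
    updated = list(state)
    apply_pending(updated, messages, nodes)
    return updated


def toy_update_memory(state, messages, nodes):
    """Persist the pending messages of `nodes` into the stored memory."""
    apply_pending(state, messages, nodes)


def process_batch(state, messages, positives, new_message):
    # 1. Before the embedding: updated view of all nodes from stored messages.
    view = toy_get_updated_memory(state, messages, range(len(state)))
    # (the embedding would be computed from `view` here)
    # 2. After the embedding: persist only this batch's positive nodes,
    #    using messages that were stored by earlier batches.
    toy_update_memory(state, messages, positives)
    # 3. Clear those nodes' messages and store this batch's new messages.
    for n in positives:
        messages[n] = [new_message]
    return view


def run_demo():
    state = [0.0, 0.0, 0.0]
    messages = {}
    process_batch(state, messages, [0, 1], 1.0)          # batch 1: nodes 0 and 1
    view = process_batch(state, messages, [1, 2], 10.0)  # batch 2: nodes 1 and 2
    return view, state, messages
```

In this toy run, node 0 appears only in batch 1. After batch 2 its stored memory is still 0.0, but its message stays pending, so the view built at the start of batch 2 already reflects it. Whether the real code behaves the same way depends on how it stores and clears messages, which should be checked against the repository.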