lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
MIT License

KNNMemory add() does not appear to update self.knns #10

Closed: vyaivo closed this issue 1 year ago

vyaivo commented 1 year ago

Thanks for the nice implementation. I've adapted this code for my own use, so I can't share a full stack that reproduces this bug, but you can check it yourself.

The following code ought to update the KNN objects in the KNNMemory class:

@delayed
def knn_add(knn, key, db_offset):
    knn.add(key, ids = knn_insert_ids + db_offset)

Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))

[link to that code here]

However, even after repeated calls that add memories, KNNMemory.search() returns empty results. Inspecting self.knns at this point shows that each KNN's is_trained attribute is still False.
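One plausible explanation, assuming joblib's default process-based backend (loky) is in use whenever `n_jobs` is greater than 1: each task's arguments are pickled and sent to a worker process, so `knn.add(...)` mutates the worker's copy while the `KNN` objects held in `self.knns` never change. The minimal sketch below simulates that copy with a `pickle` round trip rather than real worker processes. `FakeKNN` and `run_in_worker` are hypothetical stand-ins, not part of the repository.

```python
import pickle


class FakeKNN:
    """Hypothetical stand-in for the repo's KNN wrapper around a faiss index."""

    def __init__(self):
        self.is_trained = False
        self.ids = []

    def add(self, ids):
        # Mutates this object in place, as the wrapper's add() is expected to.
        self.is_trained = True
        self.ids.extend(ids)


def run_in_worker(fn, *args):
    # Simulate a process-based backend: arguments are pickled on the way in
    # and the return value is pickled on the way out.
    copied_args = pickle.loads(pickle.dumps(args))
    return pickle.loads(pickle.dumps(fn(*copied_args)))


def knn_add(knn, ids):
    knn.add(ids)  # no return value, as in the original snippet


knns = [FakeKNN(), FakeKNN()]
for knn, offset in zip(knns, [0, 10]):
    run_in_worker(knn_add, knn, [offset, offset + 1])

# Only the copies were updated, so the parent's objects still look untrained.
print([knn.is_trained for knn in knns])  # [False, False]
```

With joblib's sequential path (`n_jobs = 1`) the same code would mutate the originals in place, which may be why the bug is easy to miss.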

Modifying the code as follows fixes the issue:

@delayed
def knn_add(knn, key, db_offset):
    knn.add(key, ids = knn_insert_ids + db_offset)
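    return knn  # send the updated KNN back to the caller

updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
self.knns = updated_knns  # replace the old objects with the returned, updated ones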

With this change, searches return actual values.
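A minimal sketch of why returning the object and reassigning `self.knns` helps, assuming a process-based backend that pickles task arguments and return values: the updated copy travels back to the parent as the task's result. Real worker processes are simulated here with a `pickle` round trip, and `FakeKNN` and `run_in_worker` are hypothetical stand-ins, not part of the repository. The fix also assumes the real `KNN` object, including its faiss index, survives pickling in both directions.

```python
import pickle


class FakeKNN:
    """Hypothetical stand-in for the repo's KNN wrapper around a faiss index."""

    def __init__(self):
        self.is_trained = False
        self.ids = []

    def add(self, ids):
        self.is_trained = True
        self.ids.extend(ids)


def run_in_worker(fn, *args):
    # Simulate a process-based backend by pickling arguments and results.
    copied_args = pickle.loads(pickle.dumps(args))
    return pickle.loads(pickle.dumps(fn(*copied_args)))


def knn_add(knn, ids):
    knn.add(ids)
    return knn  # hand the mutated copy back to the caller


knns = [FakeKNN(), FakeKNN()]
updated_knns = [
    run_in_worker(knn_add, knn, [offset, offset + 1])
    for knn, offset in zip(knns, [0, 10])
]
knns = updated_knns  # mirrors `self.knns = updated_knns`

print([knn.is_trained for knn in knns])  # [True, True]
print([knn.ids for knn in knns])  # [[0, 1], [10, 11]]
```

An alternative design, if shipping whole indexes back and forth is too costly, is to keep the work in one process, for example with joblib's `Parallel(n_jobs = ..., prefer = "threads")`, where in-place mutation is visible to the caller; whether that actually runs in parallel depends on whether the underlying faiss calls release the GIL.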