lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
MIT License

Arguments to reproduce the models from the original paper? #4

Closed manestay closed 2 years ago

manestay commented 2 years ago

Hi lucidrains,

This looks like excellent work! I have gone through the original paper and your repo, and am now trying to reproduce the model from the paper as closely as possible. Of course, the modifications you made, such as hybrid attention instead of the sigmoid gate, are fine.

Specifically, I would like to be able to try some of the variations in Table 4 of the paper.

Suppose I'm interested in the 4th-to-last row, with Context 512, Memory 8192, and XL cache 512. Can you help me with the model arguments to do that? Here is my initial attempt, with reference to Section 4.2:

```python
model = MemorizingTransformer(
    num_tokens = 32000,               # vocab 32k
    dim = 1024,
    depth = 12,
    memorizing_layers = 9,
    max_knn_memories = 8192,          # Memory column
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (6, 7, 8, 9),  # not sure about this?
    xl_max_memories = 512,            # XL cache column
    shift_knn_memories_down = 1,
    shift_xl_memories_down = 1,
    # which argument corresponds to Context column?
).cuda()
```
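For what it's worth, my current understanding (an assumption on my part, not confirmed by the repo) is that the Context column is not a constructor argument at all: it is the segment length you choose when chunking the training data, with KNN and XL memories carried across the segments of a document. A minimal sketch of that chunking in plain PyTorch:

```python
import torch

context_len = 512  # "Context" column from Table 4

# a toy long document of token ids (vocab 32k, as in the paper)
doc = torch.randint(0, 32000, (1, 4096))

# split into consecutive 512-token segments; each segment would be
# one forward pass, with memories persisted between segments
segments = doc.split(context_len, dim = -1)

print(len(segments))            # 4096 / 512 = 8 segments
print(segments[0].shape)        # each segment is (1, 512)
```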

A second question: what are the model arguments to reproduce the first row of Table 4, with no memory and no XL cache? Thanks in advance.
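My own guess for that baseline (argument names carried over from my attempt above and not verified against the repo) would be to simply not designate any memorizing or XL layers, something like:

```python
# hypothetical no-memory baseline -- argument values are my assumption,
# not confirmed against the repo's API
baseline = MemorizingTransformer(
    num_tokens = 32000,
    dim = 1024,
    depth = 12,
    memorizing_layers = (),  # no KNN-memory layers
    xl_max_memories = 0,     # no XL cache
)
```

but I'm not sure whether empty memorizing layers are supported, hence the question.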

manestay commented 2 years ago

Closing this, since this repo rightly adds a lot of improvements over the original paper, which means reproducing it exactly isn't worthwhile.