lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
MIT License

Any interesting results? #1

Open rom1504 opened 2 years ago

rom1504 commented 2 years ago

Hey! Cool repo. I like all the knn+lm methods. Did you do some runs yet? Anything interesting to report?

lucidrains commented 2 years ago

:wave: hello Romain! no not yet, i still need to build out the modular forgetting system

didn't you start your new job?

rom1504 commented 2 years ago

Ok great, I'll follow up on the progress :)

Indeed I started the new job, pretty interesting!

igor0 commented 2 years ago

I hope you don't mind me stalking this project, but I tried this out on enwik8 (https://github.com/igor0/memorizing-transformers-pytorch/commit/d302feee0c3d9655a92c392850c4ec5d86bff77c). I basically just ported the enwik8 training loop from another one of @lucidrains's projects.

The initial finding is that with KNN memories, the training loop is pretty slow, so often (but not always) I'll sample 0% GPU usage. Disabling the KNN memories makes the training loop go much faster (>10x compared to with KNN). So the KNN code may need some optimization, but I don't understand it well enough yet to suggest something constructive.

Edit: Ah, that was with knn_use_gpu=False, I missed that. With knn_use_gpu=True, I seem to get a hang. On the GPU, it's faiss.index_cpu_to_all_gpus(self.index) that's hanging for me, endlessly chewing up CPU cycles. Just FYI.

lucidrains commented 2 years ago

@igor0 ah thanks for trying it! i'll hook up the enwik8 training myself next week and start profiling and see what's going on :) still want to polish up the pluggable memory expiring strategy (account for memory creation time as well as last retrieved time)

igor0 commented 2 years ago

I ended up having two issues with KNN on the GPU. Here are the findings so far.

1. The faiss-gpu wheel hangs for me on A100 and A10. With the faiss-gpu package installed via pip, I always get a hang in index_cpu_to_all_gpus(). I opened an issue here: https://github.com/kyamagu/faiss-wheels/issues/54. My guess is that the faiss-gpu wheel isn't compatible with CUDA 11, so the A100/A10 GPUs don't work.

Using conda rather than pip to install faiss-gpu seems to work for me.

2. "remove_ids not implemented for this type of index". As far as I can tell, remove_ids is not supported for any GPU index. One possible solution may be to simulate a sliding window with two GPU indexes, so that we always clear an entire index instead of removing entries one by one. Fancier expiry would get more complicated and would require some kind of manual compaction, at least if you want to run faiss on the GPU.

rom1504 commented 2 years ago

there is little benefit to using faiss gpu

however if knn operations are slow here, it's likely because a flat index is used

igor0 commented 2 years ago

What type of index to use then? One problem is that we don't know the distribution of the keys upfront, and the clustering approaches require that. Furthermore, the distribution of keys changes over time. So, you could keep recomputing the clusters periodically. I'm sure that's doable, but another thing to sort out and tune.

IMO a flat index is a reasonable place to start. And a flat index on a GPU would perform much better than a flat index on a CPU.

If faiss doesn't let us implement a flat index with the ability to replace entries, then we could implement our own sliding window mechanism, or just avoid faiss for now and simply implement the memory directly as a PyTorch tensor. That could be one straightforward solution.
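
A minimal sketch of that last option (assumed names and shapes, not from the repo): a fixed-capacity ring buffer of keys held as one tensor, with exact top-k search done by a matmul.

import torch

class TensorMemory:
    # fixed-capacity sliding-window memory held as a single tensor; exact (flat) knn
    def __init__(self, capacity, dim, device = 'cuda'):
        self.keys = torch.zeros(capacity, dim, device = device)
        self.capacity = capacity
        self.ptr = 0
        self.filled = 0

    def add(self, new_keys):  # new_keys: (n, dim)
        n = new_keys.shape[0]
        idx = (self.ptr + torch.arange(n, device = new_keys.device)) % self.capacity
        self.keys[idx] = new_keys.detach()  # overwrite the oldest slots
        self.ptr = (self.ptr + n) % self.capacity
        self.filled = min(self.filled + n, self.capacity)

    def search(self, queries, k):  # queries: (m, dim)
        sims = queries @ self.keys[:self.filled].t()  # exact inner-product scores
        return sims.topk(min(k, self.filled), dim = -1)  # (scores, indices)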

lucidrains commented 2 years ago

Hmm, yeah, maybe this won't be smooth sailing

There is also another library out there called TorchPQ that may fit the bill of running on GPU and supporting removal of ids. But it is a relatively young library, so probably not without a few rough edges. I'll take a closer look next week, thanks for trying this out so early!

lucidrains commented 2 years ago

https://github.com/DeMoriarty/TorchPQ

igor0 commented 2 years ago

FlatContainer in TorchPQ looks promising as a potential flat GPU index (to avoid the challenges with clustering): https://github.com/DeMoriarty/TorchPQ/blob/main/torchpq/container/FlatContainer.py

It seems like FlatContainer::set_data_by_address() can arbitrarily overwrite records in the flat container. That would be more efficient than FlatContainer::remove() because remove() needs to copy a lot of data around. Not sure how much that will matter in the end, but always good to avoid copying when possible.

rom1504 commented 2 years ago

Could you describe how often each of these operations is done within memorizing transformers: add, search, remove, and (re)train?

rom1504 commented 2 years ago

An example of an index that works without training (although it's not obvious that's a good property at this point) is IndexHNSW

What I meant by "there is little benefit in using faiss GPU" is that faiss indices are usually very fast (search is done in less than 1ms) on CPU, and it's not faster on GPU. The only time it's better to use GPU is if you need to query with a huge batch of embeddings (let's say 1M)

But the choice of index should be done based on how often you need to search/add/train/remove and how many vectors you have. So if you give more information on that, i can advise

lucidrains commented 2 years ago

@rom1504 thanks for offering your expertise! so basically this paper adds embeddings at a rate of 512 tokens per training step. To compound the problem, they keep separate indexed memories per batch row, which is why I have to instantiate as many faiss indices as the batch size. Searching is also done every training step (after the first), with a top k of 32, and removal of embeddings starts once the memory hits some capacity limit (in the paper they start at 2048 and then scale the memory size up to 16000). At 512 new memories per step, a capacity of 2048 means the removing of ids starts on the 5th step. So basically: high rates of adding, removing, and searching.

The author told me that what they were doing within Google is running each batch on 1 TPU core, which could thus be assigned its own index.

Flat would be fine, but it also negates the paper's main selling point, which is that fetching from approximate knn memories should benefit the attention net greatly. Hopefully it isn't the case that it "does not work in practice" due to engineering obstacles

lucidrains commented 2 years ago

even in the worst case, I think the lessons from this paper can be carried over to some other architecture (say, if one were to generalize https://github.com/lucidrains/HTM-pytorch):

1. storing l2-normed keys / values (cosine sim attention) as memories, for stability
2. memories need not be differentiable
3. approximate knn is fine
4. one only needs one or two layers of long-term memory fetching at most (placed in the middle of the attention net)
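
As a toy illustration of points 1 and 2 (not repo code; here only the keys get l2-normalized, though the values could be treated the same way):

import torch
import torch.nn.functional as F

def to_memories(k, v):
    # point 1: l2-normalize so retrieval scores behave like cosine similarity
    # point 2: detach so the stored memories stay non-differentiable
    k = F.normalize(k, dim = -1)
    return k.detach(), v.detach()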

rom1504 commented 2 years ago

ok interesting. Removing from an index is usually slow, so I would not remove. Instead I would replace the remove operation by adding removed ids to a mask (and when doing the search, you search with a higher K value and apply the mask; a higher K barely affects the search speed).
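
A minimal sketch of that masking idea (assumed dimensions and names, not from the repo): keep a set of "removed" ids, over-fetch at search time, and filter.

import faiss
import numpy as np

dim, topk = 64, 32
index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
removed = set()  # ids we pretend to have deleted

def masked_search(queries, k = topk, overfetch = 4):
    # queries: float32 (m, dim). search with a higher K, then drop masked ids and trim back to k
    dists, ids = index.search(queries, k * overfetch)
    results = []
    for row_d, row_i in zip(dists, ids):
        keep = [(d, i) for d, i in zip(row_d, row_i) if i != -1 and i not in removed]
        results.append(keep[:k])
    return results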

And maybe you rebuild the index from scratch every 1000 steps to save on memory if needed.

about add/search, I would start by simply trying IndexHNSWFlat; I believe it will work well enough. It's a little slow at adding (maybe 10ms for a batch of 512), but search is basically instant (0.1ms):

import faiss
import numpy as np

dimension = 512  # example key dimension
index = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index.add(np.random.rand(512, dimension).astype('float32'))
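
For completeness, the query side would look something like this (just a sketch; ids of -1 mean no neighbor was found):

queries = np.random.rand(4, dimension).astype('float32')
distances, ids = index.search(queries, 32)  # top k of 32, as mentioned above
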
lucidrains commented 2 years ago

@rom1504 thanks for the suggestion :+1:

on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable

igor0 commented 2 years ago

The problem is that for each training sample, we need to add the new key/value memories, remove old entries to make room, and search the memory.

So, adds : removals : searches are 1 : 1 : 1, or alternatively you can think of it as searches : replacements being 1 : 1. So, we are doing the removals in order to create space for the add. Masking out the elements doesn't really solve the problem for us because it doesn't open up new space.

One solution is to have two indexes: current and previous. We add to current, and once current fills up, clear previous, current becomes previous, and previous becomes current. So basically, we are only adding. In this scenario, current and previous don't need to support any fancy operations beyond add() and clear(), so they can probably be either flat or HNSW.
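
A rough sketch of that current/previous scheme (assumed fixed dimension; a real version would also offset ids so the two indexes don't collide):

import faiss
import numpy as np

class RotatingMemory:
    def __init__(self, dim, capacity):
        self.dim, self.capacity = dim, capacity
        self.current = faiss.IndexFlatIP(dim)  # could just as well be IndexHNSWFlat
        self.previous = faiss.IndexFlatIP(dim)

    def add(self, vecs):  # vecs: float32 (n, dim)
        if self.current.ntotal + len(vecs) > self.capacity:
            # current is full: it becomes previous, and a fresh index becomes current
            self.previous, self.current = self.current, faiss.IndexFlatIP(self.dim)
        self.current.add(vecs)

    def search(self, queries, k):
        # query both windows and merge, keeping the best k by inner product
        d1, i1 = self.current.search(queries, k)
        d2, i2 = self.previous.search(queries, k)
        d, i = np.concatenate([d1, d2], 1), np.concatenate([i1, i2], 1)
        order = np.argsort(-d, axis = 1)[:, :k]
        return np.take_along_axis(d, order, 1), np.take_along_axis(i, order, 1)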

Another solution is to use a single flat index that supports replacement: i.e., we can efficiently replace some entries with other entries. Faiss doesn't seem to support this, but you can either implement something from scratch (just represent the memory as a tensor), or use the other library that @lucidrains mentioned.

lucidrains commented 2 years ago

@igor0 yea, i feel like if we were to go with flat for faiss, there would be no benefit. the whole point is to sell the approach of using approximate knn for long term, non differentiable memory - at least, that was what excited me about the paper initially

lucidrains commented 2 years ago

maybe it would be best to forget about faiss and scann, and just try to roll something with deepmind's HTM (although i think more thought needs to be put into how to generalize HTM to more than just a depth of 1 hierarchy) - or the alternative is to just forget about this repository and focus on https://github.com/lucidrains/routing-transformer and make sure it supports recurrence and that the routing attention can act on a set of non differentiable memories

rom1504 commented 2 years ago

on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable

I don't understand why add_with_ids is needed. custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive/custom ids mapping (as a python dict, or as a numpy array)

you could decide to use faiss.IndexIDMap if you really want add with ids https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap
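
For reference, a small sketch of the IndexIDMap route (dimension and ids here are made-up values):

import faiss
import numpy as np

dim = 64
index = faiss.IndexIDMap(faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT))

vecs = np.random.rand(512, dim).astype('float32')
ids = np.arange(10_000, 10_512, dtype = 'int64')  # arbitrary custom ids
index.add_with_ids(vecs, ids)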

rom1504 commented 2 years ago

Masking out the elements doesn't really solve the problem for us because it doesn't open up new space

why is opening new space needed? The number of embeddings you add at every iteration is pretty small, so the memory use will be limited until you've done a few thousand steps. At that point you can rebuild the index

lucidrains commented 2 years ago

on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable

I don't understand why add_with_ids is needed. custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive/custom ids mapping (as a python dict, or as a numpy array)

you could decide to use faiss.IndexIDMap if you really want add with ids https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap

ohh ok, maybe it could still work then, since faiss can support a ridiculous number of vectors - realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained

i guess the other issue is frequent retraining of indices, since each batch will be maintaining its own index, and every time a new document comes along, it needs to be cleared and retrained

screenshot below for clarity, just imagine the batch dimension being around 32 - 64

[screenshot: Screenshot from 2022-04-03 13-09-52]

igor0 commented 2 years ago

The number of embeddings you add at every iteration is pretty small, so the memory use will be limited until you've done a few thousand steps. At that point you can rebuild the index

OK, you can add elements to the index until it reaches 2 x intended_size and then compact it down to intended_size. During the compaction, you can choose whatever criteria for eviction, and simply create a new index. That's an alternative to having two indices (previous and current).
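
A sketch of that compaction scheme (it assumes the raw vectors are kept on the side, since rebuilding an HNSW index needs them):

import faiss
import numpy as np

class CompactingMemory:
    def __init__(self, dim, intended_size):
        self.dim, self.intended_size = dim, intended_size
        self.vectors = np.zeros((0, dim), dtype = 'float32')
        self.index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)

    def add(self, vecs):  # vecs: float32 (n, dim)
        self.vectors = np.concatenate([self.vectors, vecs])
        self.index.add(vecs)
        if len(self.vectors) >= 2 * self.intended_size:
            # compact: keep the newest intended_size entries (any eviction criterion works)
            self.vectors = self.vectors[-self.intended_size:]
            self.index = faiss.IndexHNSWFlat(self.dim, 15, faiss.METRIC_INNER_PRODUCT)
            self.index.add(self.vectors)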

i guess the other issue is frequent retraining of indices

I don't think HNSW indices need to be trained, so training wouldn't be an issue. That's also why we shouldn't use clustering-based indices, at least at training time. (The constraints are a bit different at inference time, but let's focus on training time for now.)

igor0 commented 2 years ago

realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained

If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?

rom1504 commented 2 years ago

indeed hnsw requires no training, that's why I was suggesting it!

rom1504 commented 2 years ago

if your total number of vectors is below 1M, hnsw is quite fine. if your dimension is 1024, 1M vectors mean 4GB of ram, which is quite reasonable

beyond 10M vectors, you'll have to think a bit more, but with some smart eviction and retraining only every N (like 1000) steps, it should be ok

lucidrains commented 2 years ago

indeed hnsw requires no training, that's why I was suggesting it!

TIL!

lucidrains commented 2 years ago

realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained

If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?

yea that's true, but it would be nice if the work could be extended (to RL, for example)

rom1504 commented 2 years ago

just to give an idea of how much faiss can scale, https://rom1504.github.io/clip-retrieval/ 's backend is currently holding an index of 5B embeddings of dimension 768. It uses 200MB of ram for the index thanks to faiss's memmapping feature. The index is 800GB. The search time is 400ms because it's big and on disk. (if using sharding and in-memory, the search time could be < 1ms)

so for millions of embeddings, everything should be fine

lucidrains commented 2 years ago

ok, let me meditate on this for an hour or two (before reckless execution), since it would require a big refactor of the KNN Memory class

thank you both, and hope you are having a great weekend :)

lucidrains commented 2 years ago

to start off with, i will simply forget about expiring memories altogether, and just throw away the index at the start of a new document (or if it hits some maximum capacity)

lucidrains commented 2 years ago

@igor0 also, just so you know, the enwik8 code in my other repositories won't work out of the box. we'll need to organize the tokens into documents, fed into the transformer in order, with no overlapping documents within each batch row

lucidrains commented 2 years ago

ok, done! https://github.com/lucidrains/memorizing-transformers-pytorch/commit/81e2d3199861264698e1930e53cf113403118877 max number of memories per batch row set to a quarter of a million tokens (should be enough)

lucidrains commented 2 years ago

also, just for everyone's reference, gradient accumulation will not work as expected with memories for now (as memories need to be aligned with the correct batch row)

igor0 commented 2 years ago

@igor0 also, just so you know, the enwik8 code in my other repositories won't work out of the box.

Yeah, so there are a couple of ways to approach it. On closer look, the way your other repo (x-transformers?) handled enwik8 was also somewhat simplistic. The enwik8 dataset is XML markup that embeds the actual wiki content, and as far as I can tell the TextSamplerDataset simply operates on that raw XML text, wiki content and markup alike.

The most simplistic way to bring that into memorizing-transformers is to replace random shuffling with a static reshuffling of the text so that an access with a stride of BATCH_SIZE observes contiguous text from the enwik8 XML file. That could look something like this (untested):

from torch.utils.data import Dataset

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len, batch_size):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.batch_size = batch_size

        # if needed, drop data so that we have an integer multiple of the batch size
        batch_tokens = batch_size * seq_len
        data_len = (len(self.data) // batch_tokens) * batch_tokens
        self.data = self.data[:data_len]

    def __getitem__(self, index):
        # remap the index so that successive samples in a given batch slot
        # (i.e. accesses with a stride of batch_size) see contiguous text
        batch = index % self.batch_size
        remapped_index = batch * (len(self) // self.batch_size) + (index // self.batch_size)
        start = remapped_index * self.seq_len
        full_seq = self.data[start : start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

This code works, but has two flaws. First, we aren't clearing memory at the end of the document. Second, the XML markup is included in the training. This may be sufficient for early testing, but would need rework for a more serious use.
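
For completeness, a hypothetical way to wire the dataset above into a loader (data_train, SEQ_LEN, and BATCH_SIZE are placeholder names); shuffling must be off, since the remapping is what lines up contiguous text per batch slot:

from torch.utils.data import DataLoader

train_dataset = TextSamplerDataset(data_train, SEQ_LEN, BATCH_SIZE)
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = False)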

lucidrains commented 2 years ago

@igor0 so i have some logic to clear the memory on detection of an [sos] token https://github.com/lucidrains/memorizing-transformers-pytorch/blob/main/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L327 but i can also add the corresponding logic to clear (after the iteration) on detection of [eos]

does enwik8 come equipped with start and end tokens? i have to confess i copied the original enwik8 training code from somewhere else and never looked closely :laughing:

lucidrains commented 2 years ago

https://github.com/lucidrains/memorizing-transformers-pytorch/commit/7a69d8bafa1ca52f056a1b2209e98b1a3036c6d8 added auto clearing of memories on detection of [eos] token id

igor0 commented 2 years ago

The info is there, but it's in the XML structure that the data loader currently simply treats as text. enwik8 contains XML markup like this:

<mediawiki ...>
  <page>
     <...>
      <comment>/* Anarchist Communism */  too many brackets</comment>
      <text xml:space="preserve">{{Anarchism}}
... here goes text on anarchism
      </text>

Ideally, the data loader would parse this XML, extract individual documents, and appropriately insert the <|endoftext|> tokens. The data loader I stole from x-transformers is completely ignoring the XML structure and simply treating it as blocks of text.
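
Something along these lines could be a starting point (only a sketch; enwik8 is a truncated dump, so a real loader would have to tolerate the cut-off tail):

import xml.etree.ElementTree as ET

def iter_documents(path, eot = '<|endoftext|>'):
    # stream the dump and yield one article's text at a time, terminated by an end-of-text marker
    for _, elem in ET.iterparse(path, events = ('end',)):
        if elem.tag.endswith('text') and elem.text:
            yield elem.text + eot
        elem.clear()  # free the parsed subtree to keep memory bounded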

lucidrains commented 2 years ago

@igor0 ok, probably save that for another day (maybe wednesday), as i'm switching gears to vqgan-vae work :)

lucidrains commented 2 years ago

@igor0 so i think for testing purposes, i'll just fetch sequences of 512 * 10 in length from enwik8 and do an experiment with and without knn memory to see how it fares - mainly, i want to compare the base network, base network + xl memories, base network + knn memory, and finally both memories (xl and knn)

lucidrains commented 2 years ago

the knn memories should be ready for gradient accumulation now, so i think i should get some results by the end of friday, to figure out if this approach is as advertised

igor0 commented 2 years ago

Awesome, very exciting! And yeah, keeping things simple to prove out the approach sounds like the way to go.

lucidrains commented 2 years ago

@igor0 hey Igor, finally got the training loop going https://github.com/lucidrains/memorizing-transformers-pytorch/blob/main/train.py will let it run for a while to make sure there's no causal leakage, and then do some full tests tomorrow to see how far this memory technique can get us

lucidrains commented 2 years ago

@igor0 do let me know what you find if you end up testing it out :)

igor0 commented 2 years ago

@igor0 awesome! yeah i'll let you know what i find.

lucidrains commented 2 years ago

@igor0 same :) really hoping this paper pans out!

lucidrains commented 2 years ago

https://wandb.ai/lucidrains/memorizing-transformers/reports/memorizing-transformers--VmlldzoxODE4ODc1?accessToken=2ae4y9l100i3kfj3oa0udcsct0r495evo661lkk71w1mve1itlmphn20eiix1aul have some experiments in progress, will update it throughout the day (and potentially smooth out any bugs in the code, if unearthed)

lucidrains commented 2 years ago

it is still way too slow, may have to look into https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html - it seems like one should be able to get this to work with memmaps and faiss indices https://joblib.readthedocs.io/en/latest/auto_examples/parallel_memmap.html

lucidrains commented 2 years ago

i think it worked :) about twice as fast

i'll keep whittling at the performance next week

joblib is really cool

igor0 commented 2 years ago

Cool, awesome progress! Yeah, since each training iteration has to do thousands or tens of thousands of KNN queries (depending on model size), even if each query takes 1 ms, that's seconds or tens of seconds per training iteration. Parallelism can help somewhat, but fine-grained parallelism isn't Python's strength, so there will be a limit to how much it helps.

I think that ultimately, it may make sense to support two modes of KNN:

1. a Flat Mode: exact search over a relatively small memory, fast enough for the training loop
2. an HNSW Mode: approximate search over a much larger memory

Then, we can train first in Flat Mode with small memory in order to train the model in a reasonable amount of time. Once converged, then use HNSW Mode to finetune with larger memory, and also use HNSW for inference. IIRC, in the paper they had to train with a smaller memory initially anyways (8192 entries?) in order to get the model to converge.
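
As a strawman, the switch could be as simple as this (the mode names are made up):

import faiss

def build_memory_index(dim, mode = 'flat'):
    if mode == 'flat':  # exact search; fine while the memory is small during training
        return faiss.IndexFlatIP(dim)
    if mode == 'hnsw':  # approximate search; for larger memories at finetune / inference time
        return faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
    raise ValueError(f'unknown mode: {mode}')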