tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

EmbeddingBag and Product-Key Memory Layers #2201

Closed: Rocketknight1 closed this issue 2 years ago

Rocketknight1 commented 4 years ago

Describe the feature and the current behavior/state. FAIR has a cool paper introducing Product-Key Memory (PKM) layers: layers that can add a huge number of parameters (100M-1B) to a network with very little compute overhead.

Unfortunately, implementing them efficiently depends on the EmbeddingBag layer from PyTorch. This layer does a gather op followed by a weighted sum across the final dimension of the gather indices.
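As a reference for the semantics only, here is a minimal NumPy sketch of that composition (NumPy fancy indexing stands in for `tf.gather`, followed by a weighted sum over the bag axis). The function name `embedding_bag_naive` and the argument layout (`values` of shape `(num_rows, dim)`, `indices` and `weights` of shape `(..., bag)`) are assumptions for illustration, not the PyTorch or Addons signature. Note the `(..., bag, dim)` intermediate it builds:

```python
import numpy as np


def embedding_bag_naive(values, indices, weights=None):
    """Gather rows of `values`, then take a weighted sum over the last axis of `indices`.

    values:  (num_rows, dim) table being gathered from
    indices: (..., bag) integer row indices into `values`
    weights: (..., bag) per-index weights; defaults to all ones (plain sum)
    returns: (..., dim)
    """
    if weights is None:
        weights = np.ones(indices.shape, dtype=values.dtype)
    gathered = values[indices]  # (..., bag, dim): the large materialized intermediate
    return (gathered * weights[..., None]).sum(axis=-2)


values = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 rows, dim 3
indices = np.array([[0, 2], [1, 1]])                   # 2 bags of size 2
weights = np.array([[1.0, 0.5], [2.0, -1.0]], dtype=np.float32)
out = embedding_bag_naive(values, indices, weights)
# out == [[3.0, 4.5, 6.0], [3.0, 4.0, 5.0]]
```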

It is trivial to implement this op as a composition of two or three ops in TensorFlow, but doing so requires you to materialize the output of the gather, which in the case of Product-Key Memory layers is enormous, and usually blows out my GPU RAM. By combining these ops into a single efficient call, EmbeddingBag avoids ever materializing the extremely large pre-sum gather output. There's no efficient way to do the same in TensorFlow without a custom op.
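A rough NumPy illustration of why fusing helps, assuming a hypothetical layout where `values` is `(num_rows, dim)` and `indices`/`weights` are `(..., bag)`: accumulating one bag position at a time gives the same result as gather-then-sum, while the largest temporary is a single `(..., dim)` slice instead of `bag` of them. The Python loop only stands in for what a fused kernel does in one call; it is not how the CUDA op from this issue is written, and `embedding_bag_accumulate` is a hypothetical name.

```python
import numpy as np


def embedding_bag_accumulate(values, indices, weights=None):
    """Same result as gather-then-sum, but never builds the (..., bag, dim) tensor.

    Each bag position k is gathered and added into the output immediately,
    so the largest temporary is one (..., dim) slice.
    """
    if weights is None:
        weights = np.ones(indices.shape, dtype=values.dtype)
    out = np.zeros(indices.shape[:-1] + (values.shape[1],), dtype=values.dtype)
    for k in range(indices.shape[-1]):
        # weights[..., k, None] has shape (..., 1) and broadcasts over dim.
        out += weights[..., k, None] * values[indices[..., k]]
    return out


rng = np.random.default_rng(0)
values = rng.standard_normal((100, 8)).astype(np.float32)
indices = rng.integers(0, 100, size=(4, 5, 16))
weights = rng.standard_normal((4, 5, 16)).astype(np.float32)

# Reference: materialize the full (4, 5, 16, 8) gather, then reduce over the bag axis.
reference = (values[indices] * weights[..., None]).sum(axis=-2)
fused = embedding_bag_accumulate(values, indices, weights)
print(np.allclose(fused, reference, atol=1e-5))  # True
```

In an autodiff framework, a loop like this would typically still keep each gathered slice alive for the backward pass (the gradient with respect to `weights` needs them), which is part of why a custom op with a hand-written gradient is needed.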

I've already gotten a CUDA and (single-threaded) CPU implementation of EmbeddingBag working locally using the custom-op repo and associated docker image. I've verified correctness by comparing outputs and gradients to those from the manual composition of ops, and speed and memory usage are vastly improved. I could also contribute a TF implementation of the Product-Key Memory layer itself if desired.

Relevant information

Which API type would this fall under (layer, metric, optimizer, etc.) Layer

Who will benefit with this feature? People who want to squeeze loads of parameters into their model while maintaining fast throughput and aren't worried about overfitting. The paper used it for big autoregressive NLP Transformers, but I suspect you could deploy it in a lot of other places too.

Any other info. I have only implemented the portions of EmbeddingBag necessary for Product-Key Memory layers.

bhack commented 4 years ago

Have you checked whether the subgraph could already be fused? See https://www.tensorflow.org/lite/convert/operation_fusion?hl=en#wrap_the_composite_operation_in_a_tffunction

It would be nice to have the fusion in TF and the layer here.

bhack commented 4 years ago

Check also https://github.com/tensorflow/tensorflow/issues/32675

bhack commented 4 years ago

/cc @tanzhenyu @dynamicwebpaige for ecosystem pre-check

Rocketknight1 commented 4 years ago

I don't believe subgraph fusion is possible - I tried using XLA and it didn't resolve the memory issues. I haven't tried TFLite but I would be surprised if this op could be fused automatically, as I had to implement several tricks. In particular, the gradient for the values tensor that is gathered from cannot be an IndexedSlices object, because EmbeddingBag (especially as used in PKM layers) usually gathers many more slices from the values tensor than a normal call to tf.gather(), and so the size of a naive IndexedSlices gradient could be several times larger than the values tensor itself!

Computing the dense values gradient efficiently requires some temp memory and a call to thrust::sort_by_key, plus some custom logic to ensure that CUDA can distribute work efficiently without multiple threads writing to the same entry in the values gradient (This is similar to the PyTorch implementation). I do not think any automatic operator fusion would be able to do this correctly.
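A host-side NumPy sketch of the dense `values` gradient idea, assuming `indices`/`weights` of shape `(..., bag)`, an upstream gradient `grad_out` of shape `(..., dim)`, and a table with `num_rows` rows. Here `np.argsort` stands in for `thrust::sort_by_key` and the per-segment loop stands in for the CUDA work distribution; this is not the kernel from the issue, and `values_grad_sorted` is a hypothetical name.

```python
import numpy as np


def values_grad_sorted(indices, weights, grad_out, num_rows):
    """Dense gradient of a weighted gather-sum w.r.t. the `values` table.

    Sorting entries by the row they gathered makes every row's contributions
    contiguous, so each segment can be reduced by a single worker and no two
    workers ever write the same gradient row.
    """
    bag = indices.shape[-1]
    dim = grad_out.shape[-1]
    flat_idx = indices.reshape(-1)                  # (N,) row gathered by each entry
    flat_w = weights.reshape(-1)                    # (N,) weight of each entry
    flat_grad = grad_out.reshape(-1, dim)           # (N // bag, dim), one row per bag
    bag_of_entry = np.arange(flat_idx.size) // bag  # bag each entry belongs to

    order = np.argsort(flat_idx, kind="stable")     # host stand-in for thrust::sort_by_key
    sorted_idx = flat_idx[order]
    starts = np.flatnonzero(np.r_[True, sorted_idx[1:] != sorted_idx[:-1]])
    ends = np.r_[starts[1:], sorted_idx.size]

    grad_values = np.zeros((num_rows, dim), dtype=grad_out.dtype)
    for s, e in zip(starts, ends):                  # one "worker" per distinct row
        entries = order[s:e]
        # Weighted sum of the upstream gradients of every bag that used this row.
        grad_values[sorted_idx[s]] = flat_w[entries] @ flat_grad[bag_of_entry[entries]]
    return grad_values


rng = np.random.default_rng(1)
num_rows, dim, bag = 10, 4, 5
indices = rng.integers(0, num_rows, size=(3, 6, bag))  # many repeated rows
weights = rng.standard_normal((3, 6, bag)).astype(np.float32)
grad_out = rng.standard_normal((3, 6, dim)).astype(np.float32)

# Reference: scatter-add every per-entry contribution (np.add.at handles duplicates).
reference = np.zeros((num_rows, dim), dtype=np.float32)
contrib = weights.reshape(-1, 1) * np.repeat(grad_out.reshape(-1, dim), bag, axis=0)
np.add.at(reference, indices.reshape(-1), contrib)

grad = values_grad_sorted(indices, weights, grad_out, num_rows)
print(np.allclose(grad, reference, atol=1e-4))  # True
```

The result has the same shape as the `values` table, so its size does not grow with the number of gathered slices, unlike an IndexedSlices-style gradient.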

Also, I commented in that issue that you linked - after further testing my solution without a custom op turned out to still have huge memory usage compared to the custom op solution, and much worse performance too.

bhack commented 4 years ago

> I haven't tried TFLite but I would be surprised if this op could be fused automatically

That link wasn't strictly about TFLite; I used that documentation page only as a pointer to the composite-op fusion topic.

> Also, I commented in that issue that you linked - after further testing my solution without a custom op turned out to still have huge memory usage compared to the custom op solution, and much worse performance too.

This was just to notify other maintainers about the full history.

Rocketknight1 commented 4 years ago

No problem! Also, one question - although I'm suggesting this as a PR for tf/addons, I'd ideally like to get it into TF itself, since it's already an op in Pytorch and there are a few Transformer derivatives that are using it, which as a result can't be fully implemented in TF.

Is going via tf/addons the right way to do this?

bhack commented 4 years ago

Yes. There isn't a specific protocol for which repository to open a feature request in, but in Addons, when we receive a feature contribution proposal, we tag the issue as ecosystem-review to check whether TF core, keras-cv, keras-nlp, Model Garden, or any other ecosystem repo is already working internally on the same feature or would be interested in having the PR in their repo.

Rocketknight1 commented 4 years ago

Cool! Also, if we do include it as a PR to either tf core or addons, I propose naming it anything except "EmbeddingBag". "gather_sum" or "gather_reduce_sum" are much clearer about what it actually is.

Rocketknight1 commented 4 years ago

Just gonna bump this to make sure it doesn't get lost

bhack commented 4 years ago

/cc Gentle ping to @tanzhenyu @dynamicwebpaige for the ecosystem pre-check

bhack commented 3 years ago

@tomerk Can you help us to route this Ecosystem review?

tomerk commented 3 years ago

Checked in w/ @ematejska. The internal TF API owners & oss outreach will start a regular (bi-weekly?) review of PRs marked w/ the ecosystem-review label.

Rocketknight1 commented 3 years ago

Pinging this again to make sure it doesn't get lost!

tomerk commented 3 years ago

Thanks for the ping! Yes we are currently following up on this to make sure these reviews happen & in a timely manner.

tomerk commented 3 years ago

Notes from ecosystem review: This appears to be fairly recent work that isn't on core or KerasNLP's roadmap. Seems fine to have in addons from our perspective.

Rocketknight1 commented 3 years ago

I think that's definitely the right choice for the PKM layer - but do you think EmbeddingBag should be in TF core? It's a fairly simple op that has other uses besides PKM layers, and it's in PyTorch core.

tomerk commented 3 years ago

Oh hmm I think we missed that. We can take another look after the holidays

tomerk commented 3 years ago

Notes from ecosystem review: KerasNLP would probably be a better fit than core for EmbeddingBag, but it's not on the KerasNLP roadmap either so addons seems fine for now.

bhack commented 3 years ago

> Notes from ecosystem review: KerasNLP would probably be a better fit than core for EmbeddingBag, but it's not on the KerasNLP roadmap either so addons seems fine for now.

@tomerk With this (and also with KerasCV), do you mean that a PR would not be reviewed and merged there because it is not on the roadmap, or that it is up to the user to submit the PR in this repository or in KerasNLP?

tomerk commented 3 years ago

Checked in with @tanzhenyu. In this case KerasCV/KerasNLP don't have infrastructure in place for custom ops so this would have to go in addons. Generally speaking though if it's widely used enough and could fit in KerasNLP/KerasCV (and does not include custom ops), KerasNLP & KerasCV would welcome these sorts of user contributions even if they aren't actively on the roadmap.

bhack commented 3 years ago

> Checked in with @tanzhenyu. In this case KerasCV/KerasNLP don't have infrastructure in place for custom ops so this would have to go in addons. Generally speaking though if it's widely used enough and could fit in KerasNLP/KerasCV (and does not include custom ops), KerasNLP & KerasCV would welcome these sorts of user contributions even if they aren't actively on the roadmap.

I don't know whether the ecosystem-review process could be expanded to cover this in general. I mean that, beyond the roadmap check, the review could also include a quick "best repository fit" check.

Something like:

Once the infrastructure is ready for standalone Keras or KerasCV/KerasNLP (for now they are still on the Model Garden infrastructure) and we have a Python-only PR, it would help us to know whether anyone in the ecosystem is interested in reviewing and merging a PR related to an issue under ecosystem review.

That way we could allocate our "very limited" volunteer resources here to PRs that are not on the ecosystem roadmap and also have no potential reviewers from the TF team in other repos.

What do you think?

tomerk commented 3 years ago

Yeah that seems like a very reasonable approach. We'll try to answer that 'best-fit' question for you for future python-only things.

bhack commented 3 years ago

/cc Asking @taylanbil for feedback on compositional-op lowering vs. custom ops, as in https://github.com/pytorch/xla/issues/2403

/cc @JackCaoG

taylanbil commented 3 years ago

This discussion is lower level in the stack than where my expertise lies, so I cannot comment unfortunately. On torch_xla using EmbeddingBag lowered to XLA, I'll defer to @JackCaoG.

Rocketknight1 commented 3 years ago

By the way, the op is ready to go (EmbeddingBag at least - PKM doesn't need any CUDA and it's much simpler to put it wherever). Let me know once you know where you want me to put the PR!

bhack commented 3 years ago

> By the way, the op is ready to go (EmbeddingBag at least - PKM doesn't need any CUDA and it's much simpler to put it wherever). Let me know once you know where you want me to put the PR!

I think that we could have an EmbeddingBag only PR.

But I'd still like to wait and see whether @JackCaoG has any feedback on an efficient composite EmbeddingBag, so that we get more device coverage (e.g. Colab/TPU or TFLite).

JackCaoG commented 3 years ago

Hi @bhack, for pt/xla we used the PyTorch-native Embedding + Reduce to implement EmbeddingBag. I think there is room for improvement, but we didn't have any immediate plan to provide a more efficient lowering.

bhack commented 3 years ago

> Hi @bhack for pt/xla we used the pytorch native Embedding + Reduce to implement the EmbeddingBag. I think there is room for improvement but we didn't have any immediate plan to provide a more efficient lowering.

Thanks, do you have a code reference for that?

What are the current limits of that composition on CPU/GPU, other than TPU?

JackCaoG commented 3 years ago

https://github.com/taylanbil/dlrm/blob/tpu/tools/xla_embedding_bag.py is our EmbeddingBag workaround for now. We use the same lowering for CPU/GPU/TPU, we didn't observe any limits except the speed might not be ideal.

bhack commented 3 years ago

@Rocketknight1 can you take a look at @JackCaoG's EmbeddingBag? It is PyTorch, but we could translate it to TensorFlow ops.

Rocketknight1 commented 3 years ago

That code materializes the gather - self.embtable is an nn.Embedding, which is called with the input indices. It will have a huge performance/memory impact if used with a large indices tensor, as in PKM layers.

taylanbil commented 3 years ago

That was never intended to be a performant implementation, rather just to get something working quickly.

bhack commented 3 years ago

> We use the same lowering for CPU/GPU/TPU, we didn't observe any limits except the speed might not be ideal

@JackCaoG Is it just not ideal, or is it unusable for large index tensors?

Because if the memory and speed impact is that large in this case, we need custom C++ and CUDA ops and would have to exclude TPU.

Rocketknight1 commented 3 years ago

A quick back of the envelope calculation might illustrate this:

If we assume a large (e.g. BERT) Transformer model, similar to the ones used in the PKM paper, the hidden dimension will be 1024. If we use standard settings from BERT training (batch size 32, sequence length 512) plus the standard settings from the PKM paper (4 heads, top 32 embeddings selected per head), then the output of the gather operation before reduction will be a float32 Tensor with dimensions (32, 512, 32 * 4, 1024). This will take up 8GB of GPU memory before any gradient tensors etc. are taken into account.

With a custom op, this is not materialized, and we only create the output after reduction. This will have dimensions (32, 512, 1024), which is only 64MB at float32 precision.
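The arithmetic can be checked directly. A float32 is 4 bytes, and the 8GB and 64MB figures come out exactly in binary units (GiB and MiB); the variable names are just labels for the settings quoted above.

```python
# Memory for the gather output before and after the reduction, float32.
batch, seq_len, hidden = 32, 512, 1024
heads, top_k = 4, 32
bytes_per_float = 4

pre_reduction = batch * seq_len * (heads * top_k) * hidden * bytes_per_float
post_reduction = batch * seq_len * hidden * bytes_per_float

print(pre_reduction / 2**30, "GiB")     # 8.0 GiB
print(post_reduction / 2**20, "MiB")    # 64.0 MiB
print(pre_reduction // post_reduction)  # 128, i.e. heads * top_k
```

The ratio is simply the number of gathered slices per output position (4 heads x top 32), which is why the savings grow with the number of selected keys.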

bhack commented 3 years ago

@seanpmorgan @WindQAQ: Just to summarize:

We need to evaluate this for:

WindQAQ commented 3 years ago
  1. For EmbeddingBag, we can have both a Python and a C++ implementation. Actually, even in TF core, not all C++ ops are supported with XLA JIT compilation. If we provide a Python fallback, it's good enough for users who want to use TPU IMO (just like what XLAEmbeddingBag does).
  2. For PKM, we should wait for more citations if the limitation is a hard threshold.
bhack commented 3 years ago

For 1, I know, but given "This is on our todo list" in https://github.com/pytorch/xla/issues/2403#issuecomment-668887911, the question is whether we should spend time reviewing and maintaining custom ops in the meantime, with all the maintenance overhead custom ops carry, or whether an efficient lowering is really close on their TODO list. That's why I mentioned @JackCaoG.

bhack commented 3 years ago

@Rocketknight1 I think if we don't get any other feedback today you could start to make a PR for EmbeddingBag.

> If we assume a large (e.g. BERT) Transformer model, similar to the ones used in the PKM paper, the hidden dimension will be 1024. If we use standard settings from BERT training (batch size 32, sequence length 512) plus the standard settings from the PKM paper (4 heads, top 32 embeddings selected per head), then the output of the gather operation before reduction will be a float32 Tensor with dimensions (32, 512, 32 * 4, 1024). This will take up 8GB of GPU memory before any gradient tensors etc. are taken into account.

For this, in the meantime, do you have a small Colab example of the compositional version to share?

Rocketknight1 commented 3 years ago

I've created a preliminary PR at #2352. Don't merge it yet, there's still cleanup to be done! However, it does 'work', and you can build and test that fork and try the op out.

bhack commented 3 years ago

I am also looking into this with the JAX team, as they have the same problem. We are all in the same boat because of the missing XLA support. https://github.com/google/jax/issues/3206

AlexanderLavelle commented 1 year ago

@Rocketknight1 I am curious regarding the use of EmbeddingBag with mixed precision. Is this simple to implement? (from TF Addons)


Incompatible type conversion requested to type 'float32' for AutoCastVariable which is casted to type 'float16'

Call arguments received by layer 'embedding_bag_1' (type EmbeddingBag):
  • indices=tf.Tensor(shape=(8, 505, 11), dtype=int64)
  • weights=None

In addition, how would you recommend treating a batch? It looks like

tf.stack(list(map(eb, encoded.to_tensor())))

will retain the batch and positions correctly, but perhaps this is not the most efficient?

Thank you!
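On the batching question: assuming the Addons layer, like the NumPy reference below, reduces only over the last axis of `indices` (the call arguments above show it being called with a batched `(8, 505, 11)` tensor), mapping over examples and stacking should give the same result as one batched call, so the `map` + `stack` loop is likely unnecessary. The function `embedding_bag` here is a hypothetical stand-in, not the Addons layer. The last lines also assume `encoded` is ragged: `RaggedTensor.to_tensor()` pads with zeros by default, so padded slots gather row 0 unless they are masked out through `weights`.

```python
import numpy as np


def embedding_bag(values, indices, weights=None):
    # Hypothetical reference: weighted sum of gathered rows over the last axis of `indices`.
    if weights is None:
        weights = np.ones(indices.shape, dtype=values.dtype)
    return (values[indices] * weights[..., None]).sum(axis=-2)


rng = np.random.default_rng(2)
values = rng.standard_normal((50, 16)).astype(np.float32)
batch = rng.integers(0, 50, size=(8, 505, 11))  # same shape as the indices above

per_example = np.stack([embedding_bag(values, x) for x in batch])  # map + stack
batched = embedding_bag(values, batch)                            # single call
print(per_example.shape, np.allclose(per_example, batched))       # (8, 505, 16) True

# Zero-padding from a ragged-to-dense conversion would add row 0 into each bag.
# A 0/1 mask passed as weights cancels the padded slots.
lengths = rng.integers(1, 12, size=(8, 505))                   # true bag sizes 1..11
mask = (np.arange(11) < lengths[..., None]).astype(np.float32)  # (8, 505, 11)
padded = np.where(mask > 0, batch, 0)
masked = embedding_bag(values, padded, mask)
```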