erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Attention Mask for Packed Sequences (via Attention Bias) #129

Closed: xingyaoww closed this 7 months ago

xingyaoww commented 8 months ago

Hi @erfanzar,

Thanks for the great repo! It looks really useful for training open-source models on TPUs and GPUs!

I wonder if it would be easy to implement a feature that allows users to pass in packed sequences (see this), which would let us maximize hardware utilization?

The idea is that the user provides a list of attention_mask_in_length; then, under the hood, we tweak the attention_mask OR adjust the attention bias accordingly so that two different examples cannot attend to each other.
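To make this concrete, here is a minimal sketch in plain jax.numpy (nothing EasyDeL-specific; the function name and the exact layout of attention_mask_in_length are my own assumptions): each row lists the lengths of the examples packed into it, zero-padded, and we expand that into per-token segment ids.

```python
import jax.numpy as jnp

def lengths_to_segment_ids(attention_mask_in_length, seq_len):
    """Expand per-row packed-example lengths into per-token segment ids.

    attention_mask_in_length: (batch, max_examples) int array, e.g. [[3, 2, 0]]
    for a row that packs a 3-token and a 2-token example into seq_len tokens.
    Returns (batch, seq_len) ids: 1, 2, ... per example, 0 for padding.
    """
    positions = jnp.arange(seq_len)[None, :]                    # (1, seq_len)
    boundaries = jnp.cumsum(attention_mask_in_length, axis=-1)  # (batch, n)
    # Counting how many example boundaries are <= a position gives its
    # 0-based example index; +1 makes it 1-based so 0 can mean padding.
    seg_ids = (positions[:, :, None] >= boundaries[:, None, :]).sum(-1) + 1
    total = attention_mask_in_length.sum(axis=-1, keepdims=True)
    return jnp.where(positions < total, seg_ids, 0)

# [[3, 2, 0]] packed into 6 tokens -> [[1, 1, 1, 2, 2, 0]]
print(lengths_to_segment_ids(jnp.array([[3, 2, 0]]), seq_len=6))
```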

I see that you support attention bias for flash attention here, which could be a good starting point (e.g., just set the attention bias across sequences to -inf). However, I find that flash_func has different signatures:

  1. For GPU, flash_func is flash_attn_gpu.mha, which does not actually have a bias argument, according to here.
  2. For TPU (flash_attn_tpu.flash_attention), that argument exists.

Is this intended? If not, do we need to fix the GPU implementation to match the TPU one? Is this the best way to implement such a packing feature?
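For what it's worth, this is roughly the -inf trick I mean, as a minimal sketch assuming per-token segment ids like in my sketch above (segment_ids_to_attention_bias is hypothetical, not an existing EasyDeL function):

```python
import jax.numpy as jnp

def segment_ids_to_attention_bias(segment_ids, dtype=jnp.float32):
    """Turn (batch, seq_len) segment ids into an additive attention bias.

    Pairs in different segments (or involving padding, id 0) get a very large
    negative bias so softmax gives them ~0 weight; same-segment pairs get 0.
    Shape (batch, 1, seq_len, seq_len) so it broadcasts over attention heads.
    """
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
    not_padding = (segment_ids > 0)[:, :, None] & (segment_ids > 0)[:, None, :]
    allowed = same_segment & not_padding
    # finfo.min instead of -inf avoids NaNs for fully-masked rows.
    bias = jnp.where(allowed, 0.0, jnp.finfo(dtype).min)
    return bias[:, None, :, :].astype(dtype)
```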

erfanzar commented 8 months ago

Hello, and thanks for using EasyDeL.

It would be a cool feature to add. Is there any code from someone who has implemented this before?

For GPU flash attention, I want to use jax-triton to port the FlashAttention 2 implementation to JAX and use that; it would be an easier, faster, and better option than re-creating FlashAttention 2 in Pallas for GPUs.

Is there any related paper with additional information on the implementation? As I understand from what has been explained here, it's not that hard to implement: we just need a 3D attention mask or causal mask, which is not hard to implement for GPUs, but for TPUs there have to be some little tricks for the sharding and multi-host implementation.
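Roughly the 3D mask I have in mind (just a sketch in plain jax.numpy, not tied to any existing EasyDeL code; it assumes per-token segment ids for the packed examples):

```python
import jax.numpy as jnp

def packed_causal_mask(segment_ids):
    """(batch, seq_len) segment ids -> (batch, seq_len, seq_len) boolean mask.

    A query position may attend to a key position only if the key is not in
    the future (causal) and both tokens belong to the same packed example
    (same non-zero segment id), so examples never attend to each other.
    """
    seq_len = segment_ids.shape[-1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))        # (q, k)
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]  # (b, q, k)
    not_padding = (segment_ids > 0)[:, :, None]
    return causal[None, :, :] & same_segment & not_padding
```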

xingyaoww commented 8 months ago

Hi @erfanzar ,

Thanks for your response! I think using Jax-triton is a great idea!!

Regarding a specific implementation for packed attention: Here is what I've implemented for Megatron-LLM, based on the issue I shared earlier.

However, I don't know if it will be super helpful, since it directly uses flash-attention's flash_attn_varlen_qkvpacked_func. I quickly searched over the flash-attention repo and found the implementation of this function: here and here. That function seems to rely on varlen_fwd in their CUDA kernel implementation, which feels too low-level to me.
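In case it helps: as far as I understand, the varlen interface mostly just needs cumulative sequence lengths over the packed batch. A rough sketch of that conversion (host-side, reusing my hypothetical attention_mask_in_length layout from above):

```python
import numpy as np

def lengths_to_cu_seqlens(attention_mask_in_length):
    """Flatten per-row packed-example lengths into the cu_seqlens layout the
    varlen flash-attention kernels expect: 1D int32 cumulative lengths over
    all packed examples in the batch, starting at 0, plus the max length.
    """
    lengths = np.asarray(attention_mask_in_length).reshape(-1)
    lengths = lengths[lengths > 0]                        # drop zero padding
    cu_seqlens = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)
    max_seqlen = int(lengths.max()) if lengths.size else 0
    return cu_seqlens, max_seqlen

# [[3, 2, 0], [4, 0, 0]] -> cu_seqlens [0, 3, 5, 9], max_seqlen 4
print(lengths_to_cu_seqlens([[3, 2, 0], [4, 0, 0]]))
```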

We only have the fused_attention_kernel function in jax-triton; I'm not sure whether we could repurpose it to implement this cleverly.

Also, I wonder whether you have an estimated timeline for integrating jax-triton's attention into EasyDeL? I might also be able to take a look at this once you integrate jax-triton into the repo :-)!

erfanzar commented 8 months ago

I'll work on that ASAP, and actually I am interested in working on this. Since there are available CUDA and Triton implementations, I'll start by implementing that on GPU via jax-triton and then create a Pallas version of it. I'd be happy to integrate jax-triton attention and other low-level GPU functions into EasyDeL for sure, but it will take some time. Right now I have to debug the flash and ring attention created for TPUs and implement the SFTTrainer; after that I can work on this. My current priority is to fix the upcoming issues, but I guess I'll start integrating jax-triton into FJFormer next week and then connect FJFormer to EasyDeL.