facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

NaNs when training with `attn_bias` (f32) #684

Open · zen-d opened this issue 1 year ago

zen-d commented 1 year ago

❓ Questions and Help

Hi, I pass an attn_bias to xformers.ops.memory_efficient_attention, but get the following error:

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(831, 43, 32, 8) (torch.float32)
     key         : shape=(831, 43, 32, 8) (torch.float32)
     value       : shape=(831, 43, 32, 8) (torch.float32)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0.0
`flshattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`tritonflashattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`cutlassF` is not supported because:
    attn_bias.shape[-1] % 8 != 0
`smallkF` is not supported because:
    bias with non-zero stride not supported

In my case, attn_bias is indispensable, and it is hard to always guarantee that attn_bias.shape[-1] % 8 == 0, so how can I benefit from this repo? Thanks.
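
For context, a minimal sketch that reproduces the failing call with the shapes from the log above (the CUDA device and the 4-D bias shape are my assumptions; the original bias shape is not shown in the log):

```python
import torch
import xformers.ops as xops

# Shapes from the error log: (batch, seq_len, heads, head_dim) = (831, 43, 32, 8).
B, M, H, K = 831, 43, 32, 8
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)

# A plain f32 tensor bias whose last dimension (43) is not a multiple of 8:
# the flash kernels reject the dtype and the tensor bias, cutlassF rejects
# the unaligned last dim, and smallkF rejects the strided bias, so dispatch
# fails with the NotImplementedError shown above.
attn_bias = torch.zeros(B, H, M, M, device="cuda", dtype=torch.float32)

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```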

danthe3rd commented 1 year ago

Hi, thanks for opening this issue. That's something we can work on (see https://github.com/facebookresearch/xformers/issues/683). What type of bias do you need? Is it a learnable bias?

zen-d commented 1 year ago

@danthe3rd Thanks a lot for your prompt reply! #683 is highly related. In that thread I noticed you may work on it (https://github.com/facebookresearch/xformers/issues/683#issuecomment-1458153308). First, may I know when support for an attn_bias of type torch.Tensor with attn_bias.shape[-1] % 8 != 0 is scheduled? Is it planned for the near future? Second, if you could also support a learnable attn_bias, it would become even more attractive.

danthe3rd commented 1 year ago

The bias is currently learnable :) We just need to add this padding support. Hopefully we can get that out next week

zen-d commented 1 year ago

Wow, fantastic! Looking forward to seeing the padding support soon to relax the shape constraint.

danthe3rd commented 1 year ago

It's merged in https://github.com/facebookresearch/xformers/commit/b6be33aecb5297f3f994568cf29e194a75e47667

zen-d commented 1 year ago

@danthe3rd Thanks! Looks good, but I don't have free GPUs at the moment. I will try out the new feature ASAP.

zen-d commented 1 year ago

@danthe3rd By following these hints to do padding and slicing, I'm able to run the model now. The memory burden is significantly alleviated. Thanks for your awesome work! I will continue to monitor the training process and the final accuracy.

HINT: To use an attn_bias with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5] instead of torch.zeros([1, 1, 5, 5])
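
For reference, a minimal sketch of that padding-and-slicing trick applied to memory_efficient_attention (shapes, dtype, and device are illustrative assumptions, not the original model's):

```python
import torch
import xformers.ops as xops

# Illustrative shapes: batch=2, seq_len=43 (not a multiple of 8), heads=4, head_dim=32.
B, M, H, K = 2, 43, 4, 32
device, dtype = "cuda", torch.float32

q = torch.randn(B, M, H, K, device=device, dtype=dtype)
k = torch.randn(B, M, H, K, device=device, dtype=dtype)
v = torch.randn(B, M, H, K, device=device, dtype=dtype)

# Allocate the bias with its last dimension padded up to a multiple of 8,
# then slice back to the real length. The slice keeps the padded, aligned
# storage underneath, as described in the hint above.
padded_len = (M + 7) // 8 * 8  # 43 -> 48
attn_bias = torch.zeros(B, H, M, padded_len, device=device, dtype=dtype)[:, :, :, :M]

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```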

zen-d commented 1 year ago

Unfortunately, training diverges partway through (the loss becomes NaN), which did not happen with the original attention implementation. Could you share some insights on that? Thanks.

danthe3rd commented 1 year ago

I don't have a specific idea for this, but you can pinpoint more precisely where the NaN is coming from with anomaly detection:

torch.autograd.set_detect_anomaly(mode=True, check_nan=True)
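
For illustration, a minimal sketch of using this on a single forward/backward step (shapes and the dummy loss are assumptions, not taken from the original model):

```python
import torch
import xformers.ops as xops

# With anomaly detection on, the backward pass reports the forward op that
# first produced a NaN instead of silently propagating it.
torch.autograd.set_detect_anomaly(mode=True, check_nan=True)

B, M, H, K = 2, 16, 4, 32
q = torch.randn(B, M, H, K, device="cuda", requires_grad=True)
k = torch.randn(B, M, H, K, device="cuda", requires_grad=True)
v = torch.randn(B, M, H, K, device="cuda", requires_grad=True)
attn_bias = torch.zeros(B, H, M, M, device="cuda", requires_grad=True)

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
out.sum().backward()  # raises at the offending op if a NaN is produced
```
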
zen-d commented 1 year ago

Thanks for the suggestion. The only difference in this controlled experiment is the attention implementation, but I am not sure of the specific cause at the moment. I will dig deeper into the issue. :)

danthe3rd commented 1 year ago

Also - it looks like this is running in f32? If not, you might want to try training in f32 to see if it's related to numerical precision.

zen-d commented 1 year ago

Yes, for safety, I am training in full FP32 precision now. (In my experience, AMP training seems more prone to NaNs for Transformer-based models.)

Shannen3206 commented 1 year ago

> Yes, for safety, I am training in full FP32 precision now. (In my experience, AMP training seems more prone to NaNs for Transformer-based models.)

I met the same problem, and I found that using fp16 can solve it.

Shannen3206 commented 1 year ago

> @danthe3rd By following these hints to do padding and slicing, I'm able to run the model now. The memory burden is significantly alleviated. Thanks for your awesome work! I will continue to monitor the training process and the final accuracy.
>
> HINT: To use an attn_bias with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5] instead of torch.zeros([1, 1, 5, 5])

Hi, I found that using this method may slow down inference (#853). Do you have a better way?