lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License
4.63k stars · 395 forks

Enabling flash attention does not support BFloat16? #254

Closed Kaimary closed 3 months ago

Kaimary commented 4 months ago

When I use torch.bfloat16, enabling attn_flash causes the following error:

File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/attend.py", line 215, in flash_attn out = F.scaled_dot_product_attention( File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/attend.py", line 275, in forward return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/x_transformers.py", line 974, in forward out, intermediates = self.attend( File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/x_transformers.py", line 1390, in forward out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, *kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/x_transformers.py", line 1767, in forward x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, *kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/x_transformers/autoregressive_wrapper.py", line 284, in forward logits, cache = self.net( File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1855, in forward loss = self.module(inputs, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, kwargs) File "/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, *kwargs) File 
"/home/kaimary/miniconda3/envs/tinysql/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(args, kwargs) File "/home/kaimary/codes/train_causal_lm.py", line 250, in train loss = model(batch["input_ids"]) File "/home/kaimary/codes/train_causal_lm.py", line 296, in train(opt) RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::BFloat16 instead.

Why does Flash Attention 2.0 in the transformers library support only torch.float16 and torch.bfloat16, while attn_flash here only supports float32?
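
For context, the mismatch by itself is enough to fail at the PyTorch level: F.scaled_dot_product_attention requires query, key and value to share one dtype, as a toy example shows (shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)                          # float32, as in the traceback
k = torch.randn(1, 8, 16, 64)                          # float32
v = torch.randn(1, 8, 16, 64, dtype = torch.bfloat16)  # bfloat16

try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    print(e)  # "Expected query, key, and value to have the same dtype ..."
```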

lucidrains commented 4 months ago

@Kaimary hello! thank you for raising this issue

could you check if 1.30.2 resolves the problem?
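
(to double check which version actually ends up installed before re-running, the standard library is enough, nothing x-transformers specific:)

```python
from importlib.metadata import version

print(version('x-transformers'))  # should print 1.30.2 or newer after upgrading
```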