Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

test_vit.py not working #367

Open macrocredit opened 1 year ago

macrocredit commented 1 year ago

Hi - I tried out the test_vit.py file and it doesn't seem to work. The main error is `TypeError: MHA.__init__() got an unexpected keyword argument 'bias'`:

For `test_vit(optimized=True, fused_mlp=True)`, I got the following error:

```
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 303, in vit_base_patch16_224
    model = VisionTransformer(**model_kwargs)
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 141, in __init__
    self.patch_embed = embed_layer(
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/layers/patch_embed.py", line 41, in __init__
    raise ImportError('fused_dense is not installed')
TypeError: MHA.__init__() got an unexpected keyword argument 'bias'
```

For `test_vit(optimized=True, fused_mlp=False)`, I got the following error:

```
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 303, in vit_base_patch16_224
    model = VisionTransformer(**model_kwargs)
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 141, in __init__
    self.patch_embed = embed_layer(
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/layers/patch_embed.py", line 41, in __init__
    raise ImportError('fused_dense is not installed')
TypeError: MHA.__init__() got an unexpected keyword argument 'bias'
```

For `test_vit(optimized=False, fused_mlp=False)`, I got the following error:

```
    model = VisionTransformer(**model_kwargs)
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 163, in __init__
    self.blocks = nn.ModuleList([create_block(
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 163, in <listcomp>
    self.blocks = nn.ModuleList([create_block(
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 57, in create_block
    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/modules/block.py", line 71, in __init__
    self.mixer = mixer_cls(dim)
TypeError: MHA.__init__() got an unexpected keyword argument 'bias'
```

For `test_vit(optimized=False, fused_mlp=True)`, I got the following error:

```
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 163, in <listcomp>
    self.blocks = nn.ModuleList([create_block(
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 55, in create_block
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
File "/home/cirrascale/anaconda3/envs/XYZ/lib/python3.10/site-packages/flash_attn/models/vit.py", line 45, in create_mlp_cls
    mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
TypeError: MHA.__init__() got an unexpected keyword argument 'bias'
```

Thanks.

macrocredit commented 1 year ago

OK, I think I fixed the error. In `models/vit.py` (probably some legacy code?), I changed the following:

```python
# OLD VERSION
# mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
#                     dropout=attn_drop, fused_bias_fc=fused_bias_fc,
#                     use_flash_attn=use_flash_attn)

# NEW VERSION
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, qkv_proj_bias=qkv_bias,
                    dropout=attn_drop, fused_bias_fc=fused_bias_fc,
                    use_flash_attn=use_flash_attn)
```
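Why the error only shows up deep inside `Block`: `functools.partial` stores keyword arguments without checking them, so `partial(MHA, ..., bias=...)` builds fine and the `TypeError` fires later, when `Block` calls `mixer_cls(dim)`. Below is a minimal sketch of that mechanism. It uses a hypothetical stand-in class `FakeMHA` (not flash-attn's real `MHA`) and a hypothetical helper `accepts_kwarg` built on `inspect.signature`:

```python
import inspect
from functools import partial


class FakeMHA:
    """Hypothetical stand-in for an attention module whose constructor
    accepts `qkv_proj_bias` but no longer accepts `bias`."""

    def __init__(self, embed_dim, num_heads=1, qkv_proj_bias=True, dropout=0.0):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv_proj_bias = qkv_proj_bias
        self.dropout = dropout


def accepts_kwarg(cls, name):
    """Hypothetical helper: True if cls(...) accepts a keyword named `name`."""
    params = inspect.signature(cls).parameters
    if name in params:
        return True
    # A **kwargs parameter would also swallow the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Both partials are created without error: partial() does not validate keywords.
old_mixer_cls = partial(FakeMHA, num_heads=12, bias=True)
new_mixer_cls = partial(FakeMHA, num_heads=12, qkv_proj_bias=True)

# The failure only surfaces when the block finally calls mixer_cls(dim).
try:
    old_mixer_cls(768)
    old_error = None
except TypeError as exc:
    old_error = exc
print("old:", old_error)

mixer = new_mixer_cls(768)
print("new:", mixer.embed_dim, mixer.num_heads, mixer.qkv_proj_bias)
```

Assuming flash-attn is installed, running `accepts_kwarg(MHA, "bias")` after `from flash_attn.modules.mha import MHA` would show whether the installed version still accepts `bias`. That is a quick way to check whether `models/vit.py` and `modules/mha.py` are out of sync.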

However, when I ran the tests, two of them failed:

```
FAILED tests/models/test_vit.py::test_vit[False-True] - AssertionError: assert 0.030423879623413086 < (4 * 0.00498652458190918)
FAILED tests/models/test_vit.py::test_vit[True-True] - AssertionError: assert 0.030423879623413086 < (4 * 0.00498652458190918)
2 failed, 2 passed, 2 warnings in 35.24s
```
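Reading the assertion: the test appears to require one model's maximum error to stay below 4 times a baseline error, presumably the flash-attn model's deviation from the reference compared with timm's. The exact expression in `test_vit.py` isn't quoted here, so the helper below is a hypothetical restatement that plugs in the numbers from the failure message:

```python
def within_tolerance(model_err, baseline_err, factor=4.0):
    """Hypothetical helper: return (passed, ratio), where passed is True
    when model_err < factor * baseline_err."""
    ratio = model_err / baseline_err
    return model_err < factor * baseline_err, ratio


# Values copied from the pytest failure message.
model_err = 0.030423879623413086    # left-hand side of the assert
baseline_err = 0.00498652458190918  # multiplied by 4 on the right-hand side

passed, ratio = within_tolerance(model_err, baseline_err)
print(f"passed={passed}, error ratio={ratio:.2f}x (limit 4x)")
```

With these numbers, the checked error is about 6.1 times the baseline, so it exceeds the allowed factor of 4 by a clear margin rather than narrowly.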

One last question: the ViT flash-attention model seems to run slower than both timm and `model_ref` in the test file. Here are my test run times on an A5000 node:

| Test | flash-attn model | model_timm | model_ref |
|---|---|---|---|
| `test_vit(optimized=True, fused_mlp=True)` | 15.017242670059204 s | 11.54206895828247 s | 11.049510478973389 s |
| `test_vit(optimized=False, fused_mlp=False)` | 19.591643571853638 s | 11.633978843688965 s | 11.028218984603882 s |

Are these timings also expected, like the test failures above?
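The thread doesn't say how these times were measured. If they are wall-clock totals for a whole test, they may include model construction, weight loading, and first-call setup rather than steady-state forward passes. Here is a minimal, framework-agnostic sketch of a fairer comparison that uses warm-up iterations and an optional synchronization hook. For example, `torch.cuda.synchronize` could serve as the hook, since CUDA kernels launch asynchronously. `benchmark` and its parameters are hypothetical names, not part of flash-attn or its tests:

```python
import statistics
import time


def benchmark(fn, warmup=3, iters=10, sync=None):
    """Hypothetical helper: time fn() after warm-up and return the median
    seconds per call. `sync` is an optional callable (e.g.
    torch.cuda.synchronize) run so queued asynchronous work is counted."""
    for _ in range(warmup):  # exclude one-time setup costs
        fn()
    if sync is not None:
        sync()               # drain warm-up work before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()           # wait for the work fn() launched
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Stand-in CPU workload; for a model, fn might be `lambda: model(x)` run
# under torch.no_grad(), with sync=torch.cuda.synchronize.
print(f"{benchmark(lambda: sum(range(100_000))):.6f} s per call")
```

The median is used instead of the mean so that a single slow iteration, such as one hit by an allocator or clock hiccup, doesn't dominate the comparison.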

Thanks.

githubabin commented 9 months ago


I got the same results as you on an RTX 3080 Ti GPU. Did you find out the cause?