ecmwf / anemoi-models

Apache License 2.0

Remove conditional import of `flash_attn` #43

Open · b8raoult opened this issue 2 months ago

b8raoult commented 2 months ago

What happened?

Inference crashes without any meaningful error message (it just exits back to the shell) when training was done with flash_attn installed but inference is run in an environment without it, and vice versa.

What are the steps to reproduce the bug?

Two ways to reproduce the issue:

  1. Train a model with flash_attn installed, then run inference in an environment where flash_attn is not installed.
  2. Train the model without flash_attn, then run inference in an environment where flash_attn is installed.

This can also apply to a training run that is restarted from a checkpoint in a different environment (see the sketch after this list).
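The root cause is visible in the snippet under "Relevant log output" below: the attention backend is selected silently at import time, depending only on whether flash_attn happens to be importable. A minimal sketch of how to check which backend a given environment would pick (this probe script is illustrative, not part of anemoi-models):

import importlib.util

# Mirror the conditional import from anemoi-models: the backend is chosen
# silently, based only on whether flash_attn is importable in this environment.
if importlib.util.find_spec("flash_attn") is not None:
    from flash_attn import flash_attn_func as attn_func
    _FLASH_ATTENTION_AVAILABLE = True
else:
    from torch.nn.functional import scaled_dot_product_attention as attn_func
    _FLASH_ATTENTION_AVAILABLE = False

print("flash_attn available:", _FLASH_ATTENTION_AVAILABLE)
print("attention backend:", attn_func)

Running this in both the training and the inference environment shows when the two disagree, which is exactly the situation in which the crash occurs.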

Version

all

Platform (OS and architecture)

any

Relevant log output

# Conditional import: silently falls back to PyTorch's scaled_dot_product_attention
# when flash_attn is not installed, so the attention backend actually used depends
# on the environment rather than on the checkpoint.
try:
    from flash_attn import flash_attn_func as attn_func
except ImportError:
    from torch.nn.functional import scaled_dot_product_attention as attn_func

    _FLASH_ATTENTION_AVAILABLE = False
else:
    _FLASH_ATTENTION_AVAILABLE = True
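
In line with the issue title, one possible direction (a sketch only, not the project's API; the use_flash_attn flag and the resolve_attention_backend helper are hypothetical) would be to make the backend choice explicit, record it with the model configuration or checkpoint, and fail with a clear error message instead of silently falling back:

import importlib.util

from torch.nn.functional import scaled_dot_product_attention


def resolve_attention_backend(use_flash_attn: bool):
    # `use_flash_attn` would come from the model config / checkpoint rather
    # than from whatever happens to be installed at import time.
    if not use_flash_attn:
        return scaled_dot_product_attention
    if importlib.util.find_spec("flash_attn") is None:
        raise ImportError(
            "This model expects flash_attn, but it is not installed in the "
            "current environment. Install flash-attn or select the PyTorch "
            "scaled_dot_product_attention backend instead."
        )
    from flash_attn import flash_attn_func

    return flash_attn_func

With something like this, a mismatch between the training and inference environments surfaces as a readable exception instead of an unexplained exit back to the shell.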

Accompanying data

No response

Organisation

No response