What happened?
Inference crashes without any meaningful error message (it just exits back to the shell) when training was done with flash_attn installed but inference is run without it, and vice versa.
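Since the process exits straight back to the shell, the underlying error can sometimes be surfaced with Python's standard faulthandler module (a general debugging aid, not part of this project; the script name below is a placeholder):

```python
# Standard-library faulthandler: prints a Python traceback when the
# interpreter dies from a hard crash (segfault/abort) instead of
# exiting silently back to the shell.
import faulthandler
faulthandler.enable()

# ...then run the usual training or inference entry point.
# Equivalently, without modifying any code:
#   python -X faulthandler infer.py   # "infer.py" is a placeholder
```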
What are the steps to reproduce the bug?
Two ways to reproduce the issue:
1. Train a model with flash_attn installed, then run inference in an environment where flash_attn is not installed.
2. Train the model without flash_attn, then run inference in an environment where flash_attn is installed.

This can also apply to a training run that is restarted from a checkpoint in an environment whose flash_attn availability differs from the original one. (A quick way to check which backend an environment will use is sketched below.)
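Before reproducing, it is worth confirming which attention backend each environment will actually pick. A minimal check (a sketch; it only assumes the public flash_attn package name used in the snippet further down):

```python
# Sketch: report which attention backend this environment will use.
# Run it once in the training environment and once in the inference
# environment; if the two reports differ, the crash is expected.
try:
    import flash_attn
    print("flash_attn available, version:",
          getattr(flash_attn, "__version__", "unknown"))
except ImportError:
    print("flash_attn not installed; torch SDPA fallback will be used")
```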
Version
all
Platform (OS and architecture)
any
Relevant log output
```python
try:
    from flash_attn import flash_attn_func as attn_func
except ImportError:
    from torch.nn.functional import scaled_dot_product_attention as attn_func
    _FLASH_ATTENTION_AVAILABLE = False
else:
    _FLASH_ATTENTION_AVAILABLE = True
```
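This fallback is silent, and the two functions bound to attn_func are not drop-in replacements, which is presumably why the mismatch ends in a hard crash rather than a clear error. A sketch of the difference (hypothetical shapes; it assumes only the public flash_attn and PyTorch signatures):

```python
# Sketch (hypothetical shapes; assumes the public flash_attn / PyTorch APIs).
# The two functions aliased to `attn_func` above disagree on tensor layout
# and on dtype/device requirements:
#   flash_attn_func(q, k, v)               (batch, seqlen, nheads, headdim),
#                                          fp16/bf16 on CUDA only
#   scaled_dot_product_attention(q, k, v)  (batch, nheads, seqlen, headdim),
#                                          any float dtype, CPU or CUDA
import torch
from torch.nn.functional import scaled_dot_product_attention

q = k = v = torch.randn(2, 8, 128, 64)  # (batch=2, nheads=8, seqlen=128, headdim=64)
out = scaled_dot_product_attention(q, k, v)  # fine on CPU / float32

# flash_attn_func (if installed) would interpret the same tensors as
# (batch=2, seqlen=8, nheads=128, headdim=64) and reject float32/CPU,
# so code written against one backend can fail abruptly under the other.
```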
Accompanying data
No response
Organisation
No response