🐛 Describe the bug

I have encountered a performance problem when running a model that uses Flash Attention, traced with torch.jit.trace and executed through C++ libtorch on Windows. Inference on Windows is 2 to 3 times slower than on Linux, which makes me question whether Flash Attention is genuinely being used during the operation.
While the stable version (2.0.1) of PyTorch produces no warnings, the nightly version (2.2.0.dev20230920+cu118) emits the following warning:
```
[W sdp_utils.cpp:234] Warning: Torch was not compiled with flash attention. (function use_flash_attention)
```
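To narrow down whether the Windows build ships the flash kernel at all, a quick check like the sketch below could be run before tracing. This is a minimal sketch, assuming a CUDA-enabled build; the tensor shapes and dtype are arbitrary example values, and the `*_sdp_enabled()` calls only report whether a backend is enabled, not whether the binary was compiled with it, so the forced-flash call at the end is what actually surfaces the "not compiled with flash attention" condition.

```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

print(torch.__version__, torch.version.cuda)
# Which SDPA backends are currently enabled (not necessarily compiled in)
print("flash enabled:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math enabled:         ", torch.backends.cuda.math_sdp_enabled())

# Force the flash backend only; if the build lacks the flash kernel this
# raises (with a warning explaining why), which answers the question directly.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
print("forced flash SDPA ran, output shape:", out.shape)
```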
I'll provide more detailed steps to reproduce the bug:
1. My model has a module that employs Flash Attention, as shown in the code below (`self.backend_map` holds the backend flags passed to `sdp_kernel`):

```python
# Run scaled dot product attention with the backends selected in self.backend_map
with sdp_kernel(**self.backend_map):
    a = F.scaled_dot_product_attention(q.half(), k.half(), v.half()).transpose(2, 3)
```
2. I then convert it to a torch.jit traced model, as illustrated below:

```python
# Trace the model with an example input
with torch.no_grad():
    traced_script_module = torch.jit.trace(
        model,
        input,
        strict=True,
    )

# Save the traced model
traced_script_module.save("traced_model.pt")
```
3. I download libtorch for Windows.
4. I compile and then execute using libtorch:

```cpp
// Load the traced model and run inference
torch::jit::script::Module module;
module = torch::jit::load(model_path);
at::Tensor outputTensor = module.forward(inputs).toTensor().squeeze(0);
```
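To check whether the traced model actually dispatches to the flash kernel at inference time (rather than silently falling back to the math or memory-efficient path), one option is to load the traced module back in Python and profile a forward pass, then look for flash-attention entries in the trace. The sketch below is a rough outline: `example_input` is a hypothetical placeholder for an input matching the traced model's signature, and matching on the substring "flash" is an assumption about how the flash kernels and ops are labeled (e.g. aten::_scaled_dot_product_flash_attention), not an exact rule.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical example input; replace with a tensor matching the traced model.
example_input = torch.randn(1, 3, 224, 224, device="cuda")

module = torch.jit.load("traced_model.pt").to("cuda").eval()

with torch.no_grad():
    # Warm up so one-time JIT/optimization overhead is excluded from the profile.
    for _ in range(3):
        module(example_input)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        module(example_input)

# If flash attention runs, op/kernel names containing "flash" should appear;
# otherwise only the fallback attention kernels show up.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
flash_rows = [e for e in prof.key_averages() if "flash" in e.key.lower()]
print("flash kernels found:", bool(flash_rows))
```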