lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Feature request: don't specify attn_flash. Select when possible #173

Open pfeatherstone opened 1 year ago

pfeatherstone commented 1 year ago

Rather than manually specifying attn_flash, why not let the code figure out in the forward() method whether it can be used? As far as I can tell, you can use it whenever there is no fancy relative positional bias and you are not doing an ONNX export. I imagine it would be something like:

if not onnx_export and not t5_rel_bias and not ...:
    return F.scaled_dot_product_attention(...)
else:
    return some_vanilla_sdpa(...)
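A runnable version of the pseudocode above, as a minimal sketch assuming PyTorch 2.x (which provides F.scaled_dot_product_attention). It models the "fancy relative positional bias" as an optional additive attn_bias tensor and takes the export state as a plain boolean. The names attend, vanilla_attention, attn_bias and exporting are hypothetical and are not x-transformers API.

```python
import math

import torch
import torch.nn.functional as F


def vanilla_attention(q, k, v, attn_bias=None):
    # Explicit softmax(Q K^T / sqrt(d) + bias) V built from basic ops.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_bias is not None:
        # e.g. a T5-style relative position bias, added to the logits
        scores = scores + attn_bias
    return torch.matmul(scores.softmax(dim=-1), v)


def attend(q, k, v, attn_bias=None, exporting=False):
    # Take the fused kernel only when nothing rules it out:
    # no additive positional bias and no ONNX export in progress.
    use_flash = attn_bias is None and not exporting
    if use_flash:
        return F.scaled_dot_product_attention(q, k, v)
    return vanilla_attention(q, k, v, attn_bias)
```

Without a bias and outside export, both paths compute the same values up to floating-point error, so the dispatch only changes which kernel runs.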

I don't know if there is a way to introspect whether we're doing an ONNX export. If not, we could simply add a parameter to the forward() method, such as onnx_exporting.
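PyTorch does expose torch.onnx.is_in_onnx_export(), which reports whether an ONNX export is in progress (as far as I know this covers the classic torch.onnx.export path; behaviour with newer exporter backends may differ by version). A minimal sketch combining it with the proposed forward() parameter; resolve_exporting and onnx_exporting are hypothetical names:

```python
from typing import Optional

import torch


def resolve_exporting(onnx_exporting: Optional[bool] = None) -> bool:
    # An explicit value passed by the caller (the hypothetical
    # onnx_exporting forward() argument) always wins.
    if onnx_exporting is not None:
        return onnx_exporting
    # Otherwise fall back to PyTorch's own export flag.
    return torch.onnx.is_in_onnx_export()
```

Keeping the explicit argument as an override means callers can still force either path if introspection does not fit their export setup.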

pfeatherstone commented 1 year ago

Or at least, it would be nice if you could specify attn_flash in forward() rather than in the transformer's constructor. Then I could set it to True during training and to False during ONNX export (which also happens during training). Alternatively, I can recreate a model and copy the weights over during export.
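A minimal sketch of a per-call override: the constructor still sets a default, but forward() accepts an optional attn_flash that takes precedence when given. SketchAttention is a hypothetical single-head layer written for illustration, not the x-transformers attention class.

```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchAttention(nn.Module):
    def __init__(self, dim: int, attn_flash: bool = True):
        super().__init__()
        self.attn_flash = attn_flash  # default chosen at construction time
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

    def forward(self, x: torch.Tensor, attn_flash: Optional[bool] = None) -> torch.Tensor:
        # A per-call value overrides the constructor default.
        flash = self.attn_flash if attn_flash is None else attn_flash
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        if flash:
            return F.scaled_dot_product_attention(q, k, v)
        # Explicit attention math for the non-flash path.
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        return scores.softmax(dim=-1) @ v
```

Training code would call layer(x) as usual, while the export step would call layer(x, attn_flash=False) on the same module, with no second copy of the weights.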

pfeatherstone commented 1 year ago

The reason for this: I would like to train with scaled_dot_product_attention(), but when I checkpoint, I export to ONNX, which doesn't support Flash. So I have to create a fresh model with Flash disabled, port the weights, and then do the ONNX export. If attn_flash were specified in the forward method, I wouldn't have to copy the model; I could just change the arguments. Also, in the forward pass, you could check whether Flash is possible depending on:

If either of these is false, set flash = False. Then the user never has to think about it, and the model always does the right thing.
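For comparison, the current workaround (fresh model with Flash disabled, then port the weights) can be sketched as below. It assumes, as the thread implies, that the flash setting changes no parameters or buffers, so the two state dicts line up. build_model is a hypothetical factory you would supply; port_to_export_model is likewise a hypothetical name.

```python
from typing import Callable

import torch.nn as nn


def port_to_export_model(
    trained: nn.Module,
    build_model: Callable[[bool], nn.Module],
) -> nn.Module:
    # build_model(flash) must recreate the same architecture as `trained`,
    # differing only in whether Flash attention is enabled.
    export_model = build_model(False)
    # The flash setting is a constructor option rather than a weight,
    # so the trained parameters can be copied across directly.
    export_model.load_state_dict(trained.state_dict())
    # Switch the copy to eval mode for export; the original keeps training.
    return export_model.eval()
```

With x-transformers, build_model might be a lambda that reconstructs your Decoder (or Encoder) with attn_flash=flash and every other argument unchanged; the exact call depends on how your model is configured.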