Open pfeatherstone opened 1 year ago
Or at least, it would be nice if you could specify `attn_flash` in `forward()` rather than in the constructor of the transformer. Then during training I can set it to True, and during ONNX export (which also happens during training) I can set it to False. Alternatively, I can recreate the model and copy the weights over during export.
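A minimal sketch of the proposed usage, with `attn_flash` as a hypothetical keyword argument to `forward()` (today it can only be set via the constructor):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# training step: use the fused scaled_dot_product_attention kernel
logits = model(x, attn_flash = True)   # hypothetical forward() kwarg

# ONNX export path: fall back to the regular attention implementation
logits = model(x, attn_flash = False)  # hypothetical forward() kwarg
```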
So, the reason for this is: I would like to train with `scaled_dot_product_attention()`, but when I do my checkpointing I export to ONNX, which doesn't support Flash. So I have to create a fresh model with Flash disabled, port the weights, then do the ONNX export. If `attn_flash` were specified in the forward method, I wouldn't have to copy the model, I could just change the arguments.
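For reference, a rough sketch of the current workaround (rebuild without Flash, port the weights, then export); the architecture hyperparameters here are placeholders:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

def build_model(flash: bool) -> TransformerWrapper:
    # identical architecture either way; only the attention backend differs
    return TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8, attn_flash = flash)
    )

train_model = build_model(flash = True)
# ... training ...

# checkpoint time: rebuild without Flash, copy the weights, then export
export_model = build_model(flash = False)
export_model.load_state_dict(train_model.state_dict())
export_model.eval()

dummy = torch.randint(0, 256, (1, 1024))
torch.onnx.export(export_model, dummy, "model.onnx", opset_version = 17)
```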
Also, in the forward pass, you could check whether Flash is possible depending on:

- `sparse_topk == False`
- `talking_heads == False`
- `torch.__version__ >= 2.0.0`
- `attn_mask is not None` or `is_causal == True` (not both)
- `isinstance(rel_pos, RelativePositionBias) == False`
- `isinstance(rel_pos, DynamicPositionBias) == False`
- `not residual_attn and not cross_residual_attn`

If any of these conditions is false, set `flash = False`.
Then the user never has to worry about it, and the model will always do the right thing.
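A minimal sketch of what that check could look like, assuming the attention module exposes attributes with roughly these names (illustrative, not the actual x-transformers internals):

```python
import torch
from packaging import version
from x_transformers.x_transformers import RelativePositionBias, DynamicPositionBias

def flash_is_usable(attn, attn_mask, is_causal) -> bool:
    # attn is the attention module; the attribute names below are assumptions
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        return False
    if attn.sparse_topk or attn.talking_heads:
        return False
    # scaled_dot_product_attention accepts an explicit mask or is_causal, not both
    if attn_mask is not None and is_causal:
        return False
    # these relative position biases need to add to the pre-softmax attention scores
    if isinstance(attn.rel_pos, (RelativePositionBias, DynamicPositionBias)):
        return False
    if attn.residual_attn or attn.cross_residual_attn:
        return False
    return True
```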
Rather than manually specifying `attn_flash`, why not allow the code to adaptively figure out whether it's possible in the `forward()` method? As far as I can tell, you can use it when there is no fancy relative positional bias and you are not doing an ONNX export, so I feel like it would just be a case of checking for those conditions there. I don't know if there is a way to introspect whether we're doing an ONNX export. If not, we could simply add a parameter to the `forward()` method like `onnx_exporting`.
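For what it's worth, PyTorch does expose `torch.onnx.is_in_onnx_export()`, which returns True while `torch.onnx.export` is tracing the model, so both options could look something like this (the helper name and the `onnx_exporting` flag are just illustrative):

```python
import torch

def resolve_flash(flash_requested: bool, onnx_exporting: bool = False) -> bool:
    # explicit opt-out, e.g. model(x, onnx_exporting = True)
    if onnx_exporting:
        return False
    # True only while torch.onnx.export is tracing the graph
    if torch.onnx.is_in_onnx_export():
        return False
    return flash_requested
```

Either way, `forward()` could silently fall back to the non-Flash path during export, without the user having to rebuild the model.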