Hi, here is a pull request for a small speedup: attention is computed with the PyTorch 2 function `torch.nn.functional.scaled_dot_product_attention` when it is available. In a bit of testing I did, this makes the optimizer run about 10% faster. The optimization was essentially copied from a recent version of nanoGPT.
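The gating described above can be sketched as follows. This is a minimal illustration of the pattern, not the exact diff in this PR; the function name `attention` and the tensor shapes are assumptions for the example.

```python
import math
import torch
import torch.nn.functional as F

# Check once whether the fused attention kernel exists (PyTorch >= 2.0).
HAS_SDPA = hasattr(F, "scaled_dot_product_attention")

def attention(q, k, v):
    """Compute attention over (batch, heads, seq_len, head_dim) tensors,
    using the fused PyTorch 2 kernel when available."""
    if HAS_SDPA:
        # Fused implementation: lower memory traffic, typically faster.
        return F.scaled_dot_product_attention(q, k, v)
    # Manual fallback for older PyTorch: softmax(q k^T / sqrt(d)) v
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v
```

Either branch returns a tensor of the same shape as `q`, so callers do not need to know which path was taken.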