kevinstephano closed this issue 2 years ago.
Fwiw, we already use nvFuser for this by simply decomposing it: https://github.com/pytorch/functorch/blob/3ca3f10bf101ee3490753c9ded0a4f7bbcb32488/functorch/_src/decompositions.py#L235
Imo, ideally, nvFuser could simply reuse these decompositions. Alternatively, hopefully this makes it easier to at least match PyTorch numerics :P
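For illustration, here is how `log_softmax` (the op this issue asks about) breaks down into primitive ops. This is a plain-Python sketch over lists instead of tensors. It is not the code behind the link above. It only shows the usual numerically stable chain (max, subtract, exp, sum, log, subtract) that a decomposition of this op typically produces.

```python
import math
from typing import List


def log_softmax_decomposed(x: List[float]) -> List[float]:
    """log_softmax expressed as a chain of primitive ops (illustrative sketch)."""
    # Primitive 1: max reduction, subtracted so exp() cannot overflow.
    m = max(x)
    shifted = [v - m for v in x]
    # Primitive 2: elementwise exp followed by a sum reduction.
    total = sum(math.exp(v) for v in shifted)
    # Primitive 3: log of the sum, subtracted elementwise.
    log_total = math.log(total)
    return [v - log_total for v in shifted]


if __name__ == "__main__":
    print(log_softmax_decomposed([1.0, 2.0, 3.0]))
```

Because each step is a generic pointwise op or reduction, a fuser that already handles those primitives can, in principle, fuse the whole chain without a dedicated `log_softmax` parser entry.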
I appreciate this is generally true with AOT Autograd, but not with all ~backends~ integrations. There's also the question of whether the decompositions are optimal and aligned with the code generator's expectations, which is still to be proven out.
For example, @rdspring1 is trying GeLU through AOT Autograd, and the decomposition is ~2.5x slower, though we're still trying to debug why.
> aligned with the codegenerators expectations which is still to be proven out.
That's fair, and indeed, part of the vision of decompositions is that backends should always have the choice of whether to decompose an operator: if you have special settings/configs that map to layer_norm, we shouldn't take that away.
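Here is a minimal sketch of that opt-out. It assumes a hypothetical decomposition table keyed by op name. `DECOMPOSITIONS` and `lower` are illustrative names, not functorch or nvFuser APIs. The backend passes in the set of ops it wants to keep whole. Every other op with a registered decomposition is expanded into primitives.

```python
from typing import Dict, List, Set

# Hypothetical decomposition table (not a real functorch API):
# op name -> the primitive ops it expands into.
DECOMPOSITIONS: Dict[str, List[str]] = {
    # log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    "log_softmax": ["amax", "sub", "exp", "sum", "log", "sub"],
    # layer_norm(x) = (x - mean) * rsqrt(var + eps) * weight + bias
    "layer_norm": ["mean", "sub", "pow", "mean", "add", "rsqrt", "mul", "mul", "add"],
}


def lower(ops: List[str], native_ops: Set[str]) -> List[str]:
    """Expand ops that have a decomposition, unless the backend keeps them whole."""
    lowered: List[str] = []
    for op in ops:
        if op in DECOMPOSITIONS and op not in native_ops:
            lowered.extend(DECOMPOSITIONS[op])
        else:
            lowered.append(op)
    return lowered
```

With `native_ops={"layer_norm"}`, a backend that has a hand-tuned layer_norm path keeps that op intact, and `log_softmax` is still expanded into primitives.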
> this is generally true with AOT autograd, but not with all backends integrations
I'm not sure this is as true: it's easy enough to use these decompositions with TorchScript, and in fact I already wrote a hacky prototype that does so.
(just realized that I didn't finish my comment)
I'll follow up on the GeLU thing on Slack; in our experience it seems to work fine.
Sounds like GeLU got resolved, so now we need to see whether normalizations can hold up to this as well ❗
🚀 Feature
There is a performance opportunity to fuse the bias add of the projection linear layer with LogSoftmax, which can be expensive because the hidden size out of the projection is 30258. We just need LogSoftmax to be added to the parser. This would be a simple example:
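The sketch below assumes the shape of the pattern. It is not the author's code. In PyTorch the pattern would be something like `F.log_softmax(F.linear(h, W, b), dim=-1)`, with the linear output 30258 wide. The plain-Python version contrasts two approaches for one row:

- The unfused path first materializes the `x + bias` row, then runs log-softmax over it.
- The fused path adds the bias on the fly, so the intermediate row is never stored.

The function names are hypothetical.

```python
import math
from typing import List


def bias_log_softmax_unfused(x: List[float], bias: List[float]) -> List[float]:
    # "Kernel" 1: materialize y = x + bias (an extra vocab-sized write and read).
    y = [xi + bi for xi, bi in zip(x, bias)]
    # "Kernel" 2: numerically stable log_softmax over y.
    m = max(y)
    log_total = math.log(sum(math.exp(v - m) for v in y))
    return [v - m - log_total for v in y]


def bias_log_softmax_fused(x: List[float], bias: List[float]) -> List[float]:
    # Single "kernel": x + bias is recomputed inside each pass instead of being
    # stored. A real GPU kernel would keep the row in registers/shared memory.
    m = max(xi + bi for xi, bi in zip(x, bias))
    log_total = math.log(sum(math.exp(xi + bi - m) for xi, bi in zip(x, bias)))
    return [xi + bi - m - log_total for xi, bi in zip(x, bias)]
```

Both functions return the same values. The difference is memory traffic. With a 30258-wide row per token, skipping the stored intermediate is where the fusion would pay off.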
Here are a couple of examples of the in-model cases. The LogSoftmax is inside nn.CrossEntropyLoss:
https://github.com/kevinstephano/simple_dl_models/blob/main/bert_model_1_layer_no_opt.py
https://github.com/kevinstephano/simple_dl_models/blob/main/bert_model.py
https://github.com/kevinstephano/simple_dl_models
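The LogSoftmax shows up here because `nn.CrossEntropyLoss` combines LogSoftmax with a negative log-likelihood loss. The plain-Python sketch below reproduces that relationship with the default mean reduction. It ignores class weights, `ignore_index`, and label smoothing. The function names are illustrative.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log_softmax over one row of logits.
    m = max(row)
    log_total = math.log(sum(math.exp(v - m) for v in row))
    return [v - m - log_total for v in row]


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean of -log_softmax(row)[target], i.e. NLL applied to log_softmax."""
    losses = [-log_softmax(row)[t] for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)
```

So fusing the projection bias into LogSoftmax targets the first half of the loss computation in these models.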