Would be awesome to get @Tom94's opinion if he's around :D
Hi there, these results look pretty much as expected, given the setup.
To benefit from fully fused MLPs, you need to use FullyFusedMLP rather than CutlassMLP (the latter not being fused at all). AFAIK, PyTorch uses CUTLASS' GEMM routines under the hood as well, which likely explains the very similar performance numbers.
Cheers!
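For anyone comparing the two paths, here is a minimal sketch (assuming tinycudann's PyTorch bindings) of how the network type is selected; the layer sizes and widths below are illustrative, not the ones from the benchmark:

```python
import torch
import tinycudann as tcnn

# FullyFusedMLP only supports a small set of hidden widths (e.g. 16/32/64/128 neurons);
# CutlassMLP is the unfused fallback built on CUTLASS GEMMs.
fused_config = {
    "otype": "FullyFusedMLP",   # fused kernel path
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": 128,           # hidden width, must be one of the supported sizes
    "n_hidden_layers": 4,
}
cutlass_config = dict(fused_config, otype="CutlassMLP")  # unfused variant for comparison

fused_mlp = tcnn.Network(n_input_dims=128, n_output_dims=128, network_config=fused_config)
cutlass_mlp = tcnn.Network(n_input_dims=128, n_output_dims=128, network_config=cutlass_config)

x = torch.randn(2**16, 128, device="cuda")
y = fused_mlp(x)  # tinycudann runs its fused path internally (typically in half precision)
```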
All makes sense, and it was really easy to work with the framework, fantastic work @Tom94!
I think the intuition that @blefaudeux had was that fusing the MLP inside the transformer blocks would give us a pretty hefty speed-up. Would you agree here @Tom94? I still think it's worth investigating.
I was not sure what tinyNN was capable of handling; basically, "fusing the MLP" is easier said than done (I tried some of it myself with Triton, and it's definitely a mouthful even if you can fuse some steps decently :)). As I understand it now, I overestimated the sizes tinyNN can handle, so it may not be a great fit for you @SeanNaren in the end, given the embedding sizes you're targeting (>> 128) (@Tom94 will obviously know more).
It was great to try regardless, I learnt something new :)
Closing this for now, as I think it's definitive!
Ben suggested checking out the MLP layer in https://github.com/NVlabs/tiny-cuda-nn as it's a fused version of a typical MLP.
I created a script to benchmark this based on just the MLP: https://github.com/SeanNaren/SmallScience/blob/feat/profiling/tiny.py
I ran the script on an A100 machine using PyTorch 1.11 and CUDA 11.3.
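For context, here is a rough sketch of the kind of head-to-head comparison the script runs; the batch size, depth, and widths are illustrative and not the exact sweep from tiny.py:

```python
import time
import torch
import torch.nn as nn
import tinycudann as tcnn

def bench(model, x, iters=100):
    # Warm up, then time `iters` forward passes with explicit CUDA synchronisation.
    for _ in range(10):
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

dim, batch = 128, 2**16  # sweep dim to reproduce the "small vs large width" crossover
x = torch.randn(batch, dim, device="cuda")

torch_mlp = nn.Sequential(
    nn.Linear(dim, dim), nn.ReLU(),
    nn.Linear(dim, dim), nn.ReLU(),
    nn.Linear(dim, dim),
).cuda()

tcnn_mlp = tcnn.Network(dim, dim, {
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": dim,        # FullyFusedMLP caps out at 128 neurons per hidden layer
    "n_hidden_layers": 2,
})

with torch.no_grad(), torch.autocast("cuda"):
    print(f"nn.Linear  : {bench(torch_mlp, x) * 1e3:.3f} ms")
    print(f"tinycudann : {bench(tcnn_mlp, x) * 1e3:.3f} ms")
```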
It seems that for very small dimensions the tinycudann MLP works well, however past a certain width it loses out to a typical nn.Linear stack:
With autocast enabled (log scale):
Without autocast enabled (log scale):