SeanNaren / min-LLM

Minimal code to train a Large Language Model (LLM).

Investigating tinycudann MLP layer #8

Closed. SeanNaren closed this issue 2 years ago.

SeanNaren commented 2 years ago

Ben suggested checking out the MLP layer in https://github.com/NVlabs/tiny-cuda-nn as it's a fused version of a typical MLP.

I created a script to benchmark this based on just the MLP: https://github.com/SeanNaren/SmallScience/blob/feat/profiling/tiny.py

I ran the script on an A100 machine using PyTorch 1.11 and CUDA 11.3.
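For reference, here is a rough sketch of the kind of comparison the script runs (this is not the actual tiny.py; the batch size, width, depth, and timing loop are illustrative, and the tinycudann config keys follow the tiny-cuda-nn README):

```python
# Rough sketch of the MLP benchmark (illustrative, not the actual tiny.py):
# compares a plain nn.Linear MLP against a tinycudann MLP of the same nominal
# shape, timing forward passes under autocast.
import time

import torch
import torch.nn as nn
import tinycudann as tcnn

batch, width, depth = 4096, 1024, 2
x = torch.randn(batch, width, device="cuda")

# Baseline: a standard PyTorch MLP built from nn.Linear layers.
layers = []
for _ in range(depth):
    layers += [nn.Linear(width, width), nn.ReLU()]
torch_mlp = nn.Sequential(*layers).cuda()

# tinycudann MLP with the same nominal shape ("CutlassMLP" here, since widths
# this large are not supported by the fully fused path).
tcnn_mlp = tcnn.Network(
    n_input_dims=width,
    n_output_dims=width,
    network_config={
        "otype": "CutlassMLP",
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": width,
        "n_hidden_layers": depth,
    },
)

def bench(model, steps=100):
    # Warm up, then time forward passes with autocast enabled.
    with torch.cuda.amp.autocast():
        for _ in range(10):
            model(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(x)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps

print(f"nn.Linear MLP:  {bench(torch_mlp) * 1e3:.3f} ms/forward")
print(f"tinycudann MLP: {bench(tcnn_mlp) * 1e3:.3f} ms/forward")
```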

It seems that for very small dimensions the tinycudann MLP works well; however, past a certain point it loses to our typical nn.Linear:

With autocast enabled (log scale):

[benchmark plot: tinycudann MLP vs. nn.Linear, autocast enabled]

Without autocast enabled (log scale):

[benchmark plot: tinycudann MLP vs. nn.Linear, autocast disabled]
SeanNaren commented 2 years ago

Would be awesome to get @Tom94's opinion if he's around :D

Tom94 commented 2 years ago

Hi there, these results look pretty much as expected, given the setup.

To benefit from fully fused MLPs, you need to

  1. Use FullyFusedMLP rather than CutlassMLP (the latter not being fused at all). AFAIK, PyTorch uses CUTLASS' GEMM routines under the hood as well, which likely explains the very similar performance numbers. (See the sketch after this list.)
  2. Use MLP widths on the order of 16-128 neurons. The underlying algorithm for full MLP fusion only works with tiny networks (see section 4 of this paper). If the goal is to run MLPs on the order of thousands of neurons wide, I would not recommend using this framework.
  3. Use the CUDA/C++ API. Interestingly (and I haven't yet dug deep enough to have a good explanation), going through the PyTorch bindings introduces roughly a 2x overhead in use cases that match points (1) and (2), such as the bundled example code for learning an image. This overhead is likely negligible for the kind of MLPs used in your benchmark, though, so no qualms with the Python API there.
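For concreteness, a minimal sketch of points (1) and (2) through the PyTorch bindings (the config keys follow the tiny-cuda-nn README; the exact width and depth values here are just illustrative):

```python
# Sketch: a fully fused MLP via the tinycudann PyTorch bindings.
# Full fusion only supports small hidden widths (roughly 16-128 neurons).
import torch
import tinycudann as tcnn

fused_mlp = tcnn.Network(
    n_input_dims=64,
    n_output_dims=64,
    network_config={
        "otype": "FullyFusedMLP",   # the fused kernel; "CutlassMLP" is not fused
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": 64,            # hidden width must stay small (e.g. 16/32/64/128)
        "n_hidden_layers": 4,
    },
)

x = torch.randn(4096, 64, device="cuda")
y = fused_mlp(x)  # forward runs in the fused CUDA kernel; autograd works as usual
```

For hidden widths well above 128, the fully fused path does not apply and you fall back to the unfused CutlassMLP behaviour measured above.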

Cheers!

SeanNaren commented 2 years ago

All makes sense. It was really easy to work with the framework, fantastic work @Tom94!

I think @blefaudeux's intuition was that fusing the MLP inside the transformer blocks would give us a pretty hefty speed-up. Would you agree here @Tom94? I still think it's worth investigating.

blefaudeux commented 2 years ago

> All makes sense. It was really easy to work with the framework, fantastic work @Tom94!
>
> I think @blefaudeux's intuition was that fusing the MLP inside the transformer blocks would give us a pretty hefty speed-up. Would you agree here @Tom94? I still think it's worth investigating.

I was not sure what tinyNN was capable of handling; basically, "fusing the MLP" is easier said than done (I tried some of it myself with Triton, definitely a mouthful even if you can fuse some steps decently :)). As far as I understand, I overestimated the sizes tinyNN can handle, so it may not be a great fit for you in the end, @SeanNaren, given the embedding sizes you're targeting (>> 128) (@Tom94 will obviously know more).

SeanNaren commented 2 years ago

It was great to try regardless, I learnt something new :)

SeanNaren commented 2 years ago

Closing this for now, as I think it's definitive!