NVlabs / tiny-cuda-nn

Lightning fast C++/CUDA neural network framework

How big can you make a fully fused MLP while retaining performance benefits? #38

Open oscarknagg opened 2 years ago

oscarknagg commented 2 years ago

The speedup in this repo relies on keeping memory traffic close to the chip, i.e. in caches, registers, and shared memory. This has to stop working once an MLP is sufficiently large, but I'm unclear on where the boundary is.
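My mental model of the fused approach is the sketch below. It's deliberately simplified (the real kernel in fully_fused_mlp.cu uses tensor-core wmma fragments and hand-tuned tile sizes; `WIDTH`, `TILE`, and the scalar inner loop here are purely illustrative), but it shows why the working set has to fit on-chip:

```cuda
#include <cuda_fp16.h>

// Simplified fused-MLP forward pass: one thread block handles a TILE-row
// slice of the batch, and the activations for that slice stay in shared
// memory across ALL layers. Only the initial load, the weights, and the
// final store touch global memory.
template <int WIDTH, int TILE>
__global__ void fused_mlp_forward(const half* __restrict__ weights, // [n_layers][WIDTH][WIDTH]
                                  const half* __restrict__ input,   // [batch][WIDTH], batch % TILE == 0
                                  half* __restrict__ output,        // [batch][WIDTH]
                                  int n_layers) {
    __shared__ half act[TILE][WIDTH];
    __shared__ half nxt[TILE][WIDTH];

    const int base = blockIdx.x * TILE;  // first batch row handled by this block

    // Load this block's activation tile from global memory once.
    for (int i = threadIdx.x; i < TILE * WIDTH; i += blockDim.x)
        act[i / WIDTH][i % WIDTH] = input[(base + i / WIDTH) * WIDTH + i % WIDTH];
    __syncthreads();

    for (int l = 0; l < n_layers; ++l) {
        const half* W = weights + l * WIDTH * WIDTH;
        // act <- ReLU(act * W); no global-memory traffic between layers.
        // A real kernel would use wmma/mma tensor-core ops here instead.
        for (int i = threadIdx.x; i < TILE * WIDTH; i += blockDim.x) {
            const int r = i / WIDTH, c = i % WIDTH;
            float sum = 0.0f;
            for (int k = 0; k < WIDTH; ++k)
                sum += __half2float(act[r][k]) * __half2float(W[k * WIDTH + c]);
            nxt[r][c] = __float2half(fmaxf(sum, 0.0f));
        }
        __syncthreads();
        for (int i = threadIdx.x; i < TILE * WIDTH; i += blockDim.x)
            act[i / WIDTH][i % WIDTH] = nxt[i / WIDTH][i % WIDTH];
        __syncthreads();
    }

    // Store the final activations to global memory once.
    for (int i = threadIdx.x; i < TILE * WIDTH; i += blockDim.x)
        output[(base + i / WIDTH) * WIDTH + i % WIDTH] = act[i / WIDTH][i % WIDTH];
}
```

A non-fused implementation instead launches one GEMM per layer and round-trips the whole [batch][WIDTH] activation matrix through DRAM between layers; fusion removes exactly that traffic, which only works while a tile of activations plus one layer's weights fit in an SM's shared memory and registers.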

Does anyone know the answers to these questions:

  1. How big can you make an MLP while retaining the performance benefits? (Has anyone tested this?)
  2. Can you "trade off" a smaller batch size for a larger MLP and still keep the benefits?
  3. Would using more powerful hardware (e.g. an A100, which has a 40MB L2 cache, vs. an RTX 3090, which has 6MB) expand this performance window?

I could potentially help out with testing (3)

Tom94 commented 2 years ago

Hi there!

  1. It's a continuum, which I would classify as follows: a. As a baseline, take CUTLASS's matrix multiplication routines (implemented in CutlassMLP), since these avoid the unrelated overheads of Python frameworks. b. Compared to that baseline, I've observed significant speedups for 64-wide and smaller MLPs, moderate speedups for 128-wide MLPs, and no speedup for 256-wide MLPs, all measured on an RTX 3090. I hand-tuned the low-level kernel configurations for each width, so I'm reasonably confident in these numbers. (A config sketch for reproducing the comparison follows this list.)
  2. Unfortunately not with the structure of computation that the current implementation exploits: the fused kernels parallelize over the batch dimension, so a smaller batch mainly reduces occupancy rather than freeing on-chip memory for a wider layer's working set.
  3. Again, unfortunately no. The L2 cache is shared across all SMs and thus benefits traditional (non-fused) matmuls just as much as the fully fused approach. What would help is a larger register file, L1 cache, and shared memory. To fully exploit these, a few of the low-level kernel parameters in fully_fused_mlp.cu need to be tuned to whichever sizes are available.
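A back-of-envelope check is consistent with these numbers: a single fp16 weight matrix is WIDTH² × 2 bytes, i.e. 8 KiB at width 64 but already 128 KiB at width 256, which no longer fits comfortably on-chip next to the activation tile. If you want to reproduce the comparison, both implementations are selected purely by the network config; a sketch using the JSON layout from the README (the layer count is an arbitrary illustrative choice):

```cpp
#include <nlohmann/json.hpp>
#include <utility>

// Sketch: the two network configs behind the comparison in point 1.
// Only "otype" differs between the fused path and the CUTLASS baseline;
// sweep n_neurons (64 / 128 / ...) to trace out the performance window.
std::pair<nlohmann::json, nlohmann::json> make_configs(int n_neurons) {
    nlohmann::json fused = {
        {"otype", "FullyFusedMLP"},    // fully fused path
        {"activation", "ReLU"},
        {"output_activation", "None"},
        {"n_neurons", n_neurons},
        {"n_hidden_layers", 4},        // arbitrary choice for illustration
    };
    nlohmann::json baseline = fused;
    baseline["otype"] = "CutlassMLP";  // CUTLASS matmul baseline, any width
    return {fused, baseline};
}
```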
oscarknagg commented 2 years ago

For reference, here are the specs of an A100 vs. a 3090 (from NVIDIA's public GA100/GA102 documentation):

| Spec | A100 | RTX 3090 |
| --- | --- | --- |
| SMs | 108 | 82 |
| L2 cache | 40 MB | 6 MB |
| Combined L1/shared memory per SM | 192 KB | 128 KB |
| Register file per SM | 256 KB | 256 KB |

These numbers are quite similar on a per-SM basis, although the A100 has significantly more SMs. Do you think this would make much of a difference, provided the kernel parameters are tuned appropriately?
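For anyone wanting to check these per-SM figures on their own card, the CUDA runtime reports all of them. A small sketch (standard cudaGetDeviceProperties fields only):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Print the on-chip resources relevant to the discussion above for
// device 0. Note the L2 is a per-chip resource, while shared memory
// and the register file are per-SM.
int main() {
    cudaDeviceProp p;
    cudaGetDeviceProperties(&p, 0);
    printf("%s\n", p.name);
    printf("SMs:                  %d\n", p.multiProcessorCount);
    printf("L2 cache:             %d KiB (whole chip)\n", p.l2CacheSize / 1024);
    printf("Shared memory per SM: %zu KiB\n", p.sharedMemPerMultiprocessor / 1024);
    printf("Register file per SM: %d KiB\n", p.regsPerMultiprocessor * 4 / 1024);
    printf("Max shared per block: %zu KiB (opt-in)\n", p.sharedMemPerBlockOptin / 1024);
    return 0;
}
```

One wrinkle relevant to retuning: on Ampere a kernel must explicitly opt in to shared-memory allocations above the default 48 KiB per block, via cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, bytes).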