triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Implement BatchNorm in triton #900

Open · Jack47 opened this issue 1 year ago

Jack47 commented 1 year ago

BatchNorm (BN) is very popular in CV; almost every conv op is followed by BN. I see that the LayerNorm implementation in Triton achieves excellent HBM bandwidth, so I'm curious about implementing BatchNorm in Triton.

My questions:

  1. Is there any chance that a Triton implementation of BN could be faster than PyTorch's?
  2. Could you give more details on the Triton LayerNorm implementation, specifically why it achieves such impressive bandwidth?
ptillet commented 1 year ago

For layer norm, you reduce along the contiguous dimension, so it's possible to fuse everything into a single kernel. For batch norm, things get a little tricky because the reduction runs across the batch and spatial dimensions. I've never looked into it, so I can't tell whether it's possible to beat PyTorch, but it should be possible to at least match it.
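For context, here is a minimal sketch of the row-wise fusion described above. This is not the actual Triton tutorial kernel; the kernel name and signature are hypothetical. Each program owns one contiguous row, so the mean, variance, normalization, and affine transform all happen in a single pass over data that is loaded once.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _layernorm_fwd(x_ptr, y_ptr, w_ptr, b_ptr, N, eps, BLOCK_N: tl.constexpr):
    # One program per row: the reduction axis is contiguous in memory.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    # Mean and variance in registers, no extra trips to HBM.
    mean = tl.sum(x, axis=0) / N
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)
    # Fused affine transform.
    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0)
    y = (x - mean) * rstd * w + b
    tl.store(y_ptr + row * N + cols, y, mask=mask)

# Usage sketch: x is an (M, N) contiguous tensor, w and b are (N,) tensors.
# M, N = x.shape
# y = torch.empty_like(x)
# _layernorm_fwd[(M,)](x, y, w, b, N, 1e-5, BLOCK_N=triton.next_power_of_2(N))
```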

Jack47 commented 1 year ago

OK, thanks! I will look into it and update here if I reach any conclusions.

BobMcDear commented 6 months ago

In case anyone is still interested in this issue, I have written a Triton implementation of batch normalization here. It supports activation fusion as well as residual addition and can be substantially faster than its PyTorch counterpart (minus torch.compile, that is).
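To illustrate why the batch-norm reduction is trickier than layer norm's, here is a hypothetical sketch (not the implementation linked above). Per-channel statistics reduce over all batch and spatial positions, so assuming an NHWC-contiguous layout, each program must loop over its channel's L = N * H * W strided elements in blocks: one pass to accumulate statistics, a second to normalize. A ReLU is fused into the second pass as one example of the activation fusion mentioned above.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _batchnorm_fwd(x_ptr, y_ptr, w_ptr, b_ptr,
                   C, L,  # C channels; L = N * H * W elements per channel
                   eps, BLOCK_L: tl.constexpr):
    # One program per channel; assumes NHWC-contiguous memory, so
    # element i of channel c sits at offset i * C + c.
    c = tl.program_id(0)
    # First pass: accumulate sum and sum of squares over the whole channel.
    acc = tl.zeros((BLOCK_L,), dtype=tl.float32)
    acc_sq = tl.zeros((BLOCK_L,), dtype=tl.float32)
    for off in range(0, L, BLOCK_L):
        idx = off + tl.arange(0, BLOCK_L)
        mask = idx < L
        x = tl.load(x_ptr + idx * C + c, mask=mask, other=0.0).to(tl.float32)
        acc += x
        acc_sq += x * x
    mean = tl.sum(acc, axis=0) / L
    var = tl.sum(acc_sq, axis=0) / L - mean * mean
    rstd = 1.0 / tl.sqrt(var + eps)
    w = tl.load(w_ptr + c)
    b = tl.load(b_ptr + c)
    # Second pass: normalize, apply the affine transform, and fuse a ReLU.
    for off in range(0, L, BLOCK_L):
        idx = off + tl.arange(0, BLOCK_L)
        mask = idx < L
        x = tl.load(x_ptr + idx * C + c, mask=mask, other=0.0).to(tl.float32)
        y = (x - mean) * rstd * w + b
        y = tl.maximum(y, 0.0)  # illustrative fused activation
        tl.store(y_ptr + idx * C + c, y, mask=mask)

# Usage sketch: x is (N, H, W, C) contiguous, w and b are (C,) tensors.
# N, H, W, C = x.shape
# y = torch.empty_like(x)
# _batchnorm_fwd[(C,)](x, y, w, b, C, N * H * W, 1e-5, BLOCK_L=1024)
```

Unlike the layer-norm sketch, the input here is read twice, which is part of why matching a fused row-wise kernel's bandwidth is harder for batch norm.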