Jack47 opened 2 years ago
For layer norm, you reduce along the contiguous dimension, so it's possible to fuse everything. For batch norm, things get a little tricky. I've never looked into it, so I can't tell if it's possible to beat PyTorch, but it should be possible to at least match it.
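To illustrate the layer norm point: below is a minimal sketch (not from any linked code; names like `layernorm_fwd_kernel` are mine) of the standard one-program-per-row pattern. Because the reduction runs over the contiguous last dimension, the mean, variance, normalization, and scale/shift all stay inside a single kernel launch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def layernorm_fwd_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    N,                      # number of columns (the normalized dimension)
    stride_row,             # elements between consecutive rows
    eps,
    BLOCK_N: tl.constexpr,  # power of two >= N
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)

    # Reduction over the contiguous dimension: mean and variance stay in registers.
    mean = tl.sum(x, axis=0) / N
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / N
    inv_std = 1.0 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    y = (x - mean) * inv_std * w + b

    tl.store(out_ptr + row * stride_row + cols, y, mask=mask)


def layernorm(x, weight, bias, eps=1e-5):
    M, N = x.shape
    out = torch.empty_like(x)
    BLOCK_N = triton.next_power_of_2(N)
    layernorm_fwd_kernel[(M,)](x, weight, bias, out, N, x.stride(0), eps, BLOCK_N=BLOCK_N)
    return out
```

For batch norm, the statistics instead run across the batch and spatial dimensions per channel, so a single row-local reduction like this no longer covers it.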
OK, thanks! I will look into it and will update here if I reach any conclusions.
In case anyone is still interested in this issue, I have written a Triton implementation of batch normalization here. It supports activation fusion as well as residual addition and can be substantially faster than its PyTorch counterpart (minus torch.compile, that is).
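For readers who just want the flavor of such a fusion without reading the linked code: here is a rough sketch of the inference-mode case only (this is not the linked implementation; the kernel and helper names are made up, and an NCHW-contiguous input with precomputed running statistics is assumed). With the statistics fixed, batch norm collapses to a per-channel affine transform, so the residual add and the activation fold into one elementwise pass.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bn_act_residual_kernel(
    x_ptr, res_ptr, out_ptr,
    mean_ptr, inv_std_ptr, weight_ptr, bias_ptr,
    numel, HW, C,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel

    # Recover the channel index from the flat NCHW offset.
    c = (offs // HW) % C

    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    r = tl.load(res_ptr + offs, mask=mask, other=0.0).to(tl.float32)

    mean = tl.load(mean_ptr + c, mask=mask, other=0.0)
    inv_std = tl.load(inv_std_ptr + c, mask=mask, other=0.0)
    w = tl.load(weight_ptr + c, mask=mask, other=1.0)
    b = tl.load(bias_ptr + c, mask=mask, other=0.0)

    y = (x - mean) * inv_std * w + b + r  # BN (running stats) + residual
    y = tl.maximum(y, 0.0)                # fused ReLU

    tl.store(out_ptr + offs, y, mask=mask)


def bn_relu_add(x, residual, running_mean, running_var, weight, bias, eps=1e-5):
    N, C, H, W = x.shape
    out = torch.empty_like(x)
    inv_std = (running_var.float() + eps).rsqrt()
    numel = x.numel()
    BLOCK = 1024
    grid = (triton.cdiv(numel, BLOCK),)
    bn_act_residual_kernel[grid](
        x, residual, out,
        running_mean.float(), inv_std, weight.float(), bias.float(),
        numel, H * W, C, BLOCK=BLOCK,
    )
    return out
```

Training mode is the harder part: the mean and variance have to be reduced over N*H*W per channel, which typically means a two-pass kernel or atomics, and that cross-row reduction is what makes matching PyTorch nontrivial.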
Batch norm (BN) is very popular in CV; almost every conv op is followed by a BN. I see that the layer norm in Triton achieves the best HBM bandwidth, so I'm curious about implementing batch norm in Triton.
My questions: