NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Fast CUDA NHWC Group Norm #1695

Closed: alpha0422 closed this 12 months ago

alpha0422 commented 1 year ago

This PR adds an efficient CUDA implementation of Group Norm for NHWC (channels-last) tensors, which many diffusion models such as Stable Diffusion require. We tested it with Stable Diffusion and Imagen, and it improves performance significantly.
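For readers less familiar with the layout: in NHWC order the channels of each pixel are stored contiguously, so the channels belonging to one group sit next to each other in memory for every pixel. The snippet below is a minimal pure-Python reference sketch of what Group Norm computes on such a buffer (per sample and per group, normalize by the mean and biased variance, then apply a per-channel gamma/beta, as torch.nn.GroupNorm does). It is not the CUDA kernel from this PR, and the function name `group_norm_nhwc` and its signature are hypothetical, for illustration only.

```python
import math
from typing import List


def group_norm_nhwc(x: List[float], n: int, h: int, w: int, c: int,
                    groups: int, gamma: List[float], beta: List[float],
                    eps: float = 1e-5) -> List[float]:
    """Hypothetical reference GroupNorm over a flat buffer in NHWC order.

    Not the apex API and not the PR's CUDA kernel; a readable sketch only.
    """
    if c % groups != 0:
        raise ValueError("channels must be divisible by groups")
    cpg = c // groups  # channels per group
    out = [0.0] * len(x)
    for b in range(n):
        for g in range(groups):
            # Collect flat indices of every element in (sample b, group g).
            # Flat NHWC index of (b, pixel p, channel ch) is (b*h*w + p)*c + ch,
            # so a group's cpg channels are contiguous for each pixel.
            idx = []
            for p in range(h * w):
                base = (b * h * w + p) * c + g * cpg
                idx.extend(range(base, base + cpg))
            mean = sum(x[i] for i in idx) / len(idx)
            # Biased variance (divide by count), matching torch.nn.GroupNorm.
            var = sum((x[i] - mean) ** 2 for i in idx) / len(idx)
            inv_std = 1.0 / math.sqrt(var + eps)
            for i in idx:
                ch = i % c  # channel index of this element
                out[i] = (x[i] - mean) * inv_std * gamma[ch] + beta[ch]
    return out


# One sample, 1x2 pixels, 2 channels, 2 groups (one channel per group).
print(group_norm_nhwc([1.0, 2.0, 3.0, 4.0], 1, 1, 2, 2, 2,
                      [1.0, 1.0], [0.0, 0.0], eps=0.0))
```

A fast kernel computes the same statistics but reduces each group in parallel instead of looping; the sketch is only meant as a correctness reference, e.g. for comparing against a kernel's output in a test.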

The CUDA kernels were originally developed by Julien Demouth and Nikita Korobov.

alpha0422 commented 1 year ago

Hi @ptrblck, could you review and merge the PR? Various diffusion models require it, thanks!

alpha0422 commented 1 year ago

I just added a couple of tests to apex/contrib/test/group_norm/test_group_norm.py.

alpha0422 commented 1 year ago

Done, thanks for your suggestions.

felixdae commented 6 months ago

Great work! I have a question: if I want to support more channels, what should I do?

alpha0422 commented 6 months ago


If you want to contribute support for more channels to this GroupNorm implementation, feel free to open a PR. Meanwhile, this functionality will be moved into cuDNN, where more channel counts will be supported.