NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
8.2k stars, 1.36k forks

Tkurth/sgbn fixes #1685

Closed (azrael417 closed this 1 year ago)

azrael417 commented 1 year ago

This PR fixes the single-node group batch norm in Apex so that it works with CUDA 12.2 and RTC.

Aidyn-A commented 1 year ago

@eqy, @rmhaskarnvidia, please review this PR and/or suggest a reviewer. I will also take a look, but I am not familiar with cudnn-frontend.

eqy commented 1 year ago

@azrael417, we can defer the issues I brought up to a later PR if @crcrpar is content to merge the fix, given the urgency.

crcrpar commented 1 year ago

rel: https://github.com/NVIDIA/apex/issues/1689