A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
`apex.contrib.group_norm` should have an import guard for `group_norm_cuda` #1701
Open
crcrpar opened 12 months ago
https://github.com/NVIDIA/apex/blob/50ac8425403b98147cbb66aea9a2a27dd3fe7673/apex/contrib/group_norm/group_norm.py#L21
Some contrib modules don't have one, though, e.g. https://github.com/NVIDIA/apex/blob/50ac8425403b98147cbb66aea9a2a27dd3fe7673/apex/contrib/layer_norm/layer_norm.py#L5
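For illustration, an import guard of this kind could look roughly like the sketch below. It assumes the goal is that importing the Python module should not fail when the compiled `group_norm_cuda` extension is missing or fails to load, and that the error should instead be raised when the extension is actually used. The helpers `try_import` and `require` are hypothetical names introduced here, not existing apex functions.

```python
import importlib
from types import ModuleType
from typing import Optional


def try_import(name: str) -> Optional[ModuleType]:
    """Import `name`, returning None instead of raising if it cannot be loaded."""
    try:
        return importlib.import_module(name)
    except ImportError:
        # Catches both a missing module (ModuleNotFoundError is a subclass of
        # ImportError) and a compiled extension that exists but fails to load.
        return None


# Guarded import (hypothetical pattern): importing this Python module no longer
# fails when the CUDA extension was not built; the failure is deferred to use.
group_norm_cuda = try_import("group_norm_cuda")


def require(module: Optional[ModuleType], name: str) -> ModuleType:
    """Raise a descriptive error at call time if a guarded import failed."""
    if module is None:
        raise RuntimeError(
            f"{name} is not available; apex was likely installed without "
            "this CUDA extension"
        )
    return module
```

Code that needs the extension would then call something like `ext = require(group_norm_cuda, "group_norm_cuda")` right before using it, so users who never touch group norm are unaffected by a missing build.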
cc @ptrblck @xwang233