NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

`apex.contrib.group_norm` should have an import guard for `group_norm_cuda` #1701

Open crcrpar opened 12 months ago

crcrpar commented 12 months ago

https://github.com/NVIDIA/apex/blob/50ac8425403b98147cbb66aea9a2a27dd3fe7673/apex/contrib/group_norm/group_norm.py#L21

Some contrib modules don't have one, though, e.g. https://github.com/NVIDIA/apex/blob/50ac8425403b98147cbb66aea9a2a27dd3fe7673/apex/contrib/layer_norm/layer_norm.py#L5

cc @ptrblck @xwang233

crcrpar commented 12 months ago

Alternative: add a check in the tests and skip them accordingly.
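A sketch of that alternative, assuming the tests use the standard `unittest` module. `importlib.util.find_spec` reports whether `group_norm_cuda` is importable without actually importing it; the test class `GroupNormTest` and its body are hypothetical placeholders, not apex's test suite.

```python
import importlib
import importlib.util
import unittest

# True only if the compiled extension can be located on sys.path.
HAS_GROUP_NORM_CUDA = importlib.util.find_spec("group_norm_cuda") is not None


@unittest.skipUnless(HAS_GROUP_NORM_CUDA, "group_norm_cuda extension is not built")
class GroupNormTest(unittest.TestCase):
    def test_module_imports(self):
        # Placeholder body: a real test would exercise the fused group norm kernel.
        module = importlib.import_module("apex.contrib.group_norm")
        self.assertIsNotNone(module)
```

With the skip on the class, every test method in it is reported as skipped rather than erroring with `ImportError` on builds that lack the extension.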