NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

[GroupNorm] Skip GroupNorm tests on A16, A2, etc. #1858

Closed: eqy closed this 2 days ago

eqy commented 2 days ago

The grid-dimension calculations (among others) seem to assume a minimum number of SMs on the device, e.g.:

grid.y = std::min(max_blocks_per_grid / blocks_per_slice, params.groups * params.n);

On devices with a low SM count (10), such as the A16 and A2, this assumption causes grid.y to be set to 0. Disabling this test for now.

CC @alpha0422 @crcrpar @ptrblck