NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

fixup concats for grouped convolution #1811

Open techshoww opened 5 months ago

techshoww commented 5 months ago

The function `fixup_concats` in `apex/contrib/sparsity/permutation_lib.py` does not take grouped convolution into account, so it computes a wrong `sibling_group_C_params` for a node that is a grouped convolution or that has grouped-convolution siblings. This pull request adds a line of code to handle that case. The issue in detail:

(figure: sibling nodes A and B feeding child nodes P1 and P2)

As shown in the figure above, nodes A and B are siblings; assume their `sibling_group_C_params` is 16 and that the output channels of P1 and P2 are 32. When ASP runs `fixup_concats` for node A, it may compute a `children_GCD_param` of 32, so the `sibling_group_C_params` of nodes A and B is updated to 32. When ASP then permutes node B, the exception `AssertionError: sibling B's weights' C=16 must be even multiple of the sibling group's C parameter 32` is raised.
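Below is a minimal, self-contained sketch of the idea behind the fix, not the actual apex code: the sibling/children dictionaries, the `groups` key, and the helper `sibling_group_c` are illustrative assumptions. It only shows why the children's channel GCD has to be shrunk by a grouped convolution's group count so it stays compatible with every sibling.

```python
from functools import reduce
from math import gcd

def sibling_group_c(children_c, siblings):
    # Illustrative only: keep the children's channel GCD compatible with every
    # sibling by dividing a grouped convolution's share by its group count.
    params = [children_c] + [children_c // s.get("groups", 1) for s in siblings]
    return reduce(gcd, params)

# Children P1 and P2 both produce 32 output channels.
children_gcd = 32

# Sibling A is a plain convolution; sibling B is a grouped convolution with
# groups=2, so each of its weight groups only sees C=16 input channels.
siblings = [{"name": "A"}, {"name": "B", "groups": 2}]

# Without the correction, sibling_group_C_params would be bumped from 16 to 32,
# and node B later fails the "C=16 must be even multiple of 32" assertion.
print(children_gcd)                             # 32 (incompatible with B)

# With the correction, the value stays at 16, which both siblings satisfy.
print(sibling_group_c(children_gcd, siblings))  # 16
```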