NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

fixup concats for grouped convolution #1811

Open techshoww opened 2 weeks ago

techshoww commented 2 weeks ago

The function `fixup_concats` in `apex/contrib/sparsity/permutation_lib.py` does not take grouped convolutions into account: it computes a wrong `sibling_group_C_params` for a node that is a grouped convolution, or that has grouped-convolution siblings. This pull request adds one line of code to handle that case. The issue in detail:

(figure: sibling nodes A and B fed by parent nodes P1 and P2)

As shown in the figure above, nodes A and B are siblings; assume their `sibling_group_C_params` is 16 (for a grouped convolution the weight's C dimension is `in_channels / groups`), while the output channels of P1 and P2 are 32. When ASP runs `fixup_concats` for node A, it may compute a `children_GCD_param` of 32 and update the `sibling_group_C_params` of both node A and node B to 32. When ASP then performs the permutation for node B, the exception `AssertionError: sibling B's weights' C=16 must be even multiple of the sibling group's C parameter 32` is raised.
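To make the arithmetic concrete, here is a minimal standalone sketch (not the actual code in `permutation_lib.py`; the function and variable names are hypothetical) showing why the GCD of the parents' output channels must be divided by the grouped convolution's group count before it is used as the siblings' channel-group parameter:

```python
# Standalone illustration of the grouped-convolution fix described above.
# Helper names (sibling_group_C, parent_out_channels, ...) are hypothetical.
from math import gcd
from functools import reduce

def sibling_group_C(parent_out_channels, sibling_groups):
    """Channel-group parameter that every sibling's weight C dim must divide into."""
    # GCD over all parents' output channels, as the fixup_concats-style logic does.
    children_gcd = reduce(gcd, parent_out_channels)
    # A grouped convolution's weight has C = in_channels / groups, so the shared
    # parameter must also be divided by the siblings' group count.
    return children_gcd // max(sibling_groups)

# Example from the description: P1 and P2 both output 32 channels; siblings A and B
# are grouped convolutions with groups=2, so their weights' C dimension is 32 / 2 = 16.
parent_out_channels = [32, 32]
sibling_groups = [2, 2]

print(sibling_group_C(parent_out_channels, sibling_groups))  # 16, matches the weights' C
# Without the division the parameter stays 32, and a weight with C=16 cannot be an
# even multiple of 32 -- which is exactly the AssertionError raised for node B.
```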