Describe the Bug
When the number of features (`feature_size` in the code below) is not a power of two, `apex.parallel.SyncBatchNorm` produces wrong results.
I tested it step by step and found that `syncbn.welford_mean_var_c_last` produces a wrong mean and var when `feature_h` and `feature_w` are large enough (see the minimal reproduction code below).
Minimal Steps/Code to Reproduce the Bug
## When `feature_h` and `feature_w` are small, it produces the correct result.
```python
import torch
import syncbn  # CUDA extension built with apex

feature_size = 65  # not a power of two
feature_h = 10
feature_w = 10
# With small feature_h and feature_w, the mean and var are correct.
# The syncbn kernels are assumed to be CUDA-only, so the input is created on the GPU.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # channels-last layout
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # True
```
## When `feature_h` and `feature_w` are large, it produces a wrong result.
```python
import torch
import syncbn  # CUDA extension built with apex

feature_size = 65  # not a power of two
feature_h = 100
feature_w = 100
# With large feature_h and feature_w, the mean and var are wrong.
# The syncbn kernels are assumed to be CUDA-only, so the input is created on the GPU.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # channels-last layout
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # False
```
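To rule out `torch.var_mean` as the source of the mismatch, the same statistics can be cross-checked against an independent pure-PyTorch reference. The sketch below merges per-chunk statistics with the parallel Welford (Chan) update. `welford_reference` and its `chunk` parameter are hypothetical names introduced here, not part of apex, and this is not a description of how the CUDA kernel is implemented. It only computes the same per-channel mean and biased variance, and it runs on CPU.

```python
import torch


def welford_reference(x_clast: torch.Tensor, chunk: int = 1024):
    """Per-channel mean and biased variance of a channels-last tensor.

    Hypothetical helper (not part of apex): merges per-chunk statistics
    with the parallel Welford/Chan update, accumulating in float64.
    """
    c = x_clast.shape[-1]
    flat = x_clast.reshape(-1, c).double()  # shape (N*H*W, C)
    count = 0
    mean = torch.zeros(c, dtype=torch.float64, device=flat.device)
    m2 = torch.zeros(c, dtype=torch.float64, device=flat.device)  # sum of squared deviations
    for part in flat.split(chunk, dim=0):
        n_b = part.shape[0]
        mean_b = part.mean(dim=0)
        m2_b = ((part - mean_b) ** 2).sum(dim=0)
        delta = mean_b - mean
        total = count + n_b
        # Merge the running statistics with the statistics of this chunk.
        mean = mean + delta * n_b / total
        m2 = m2 + m2_b + delta ** 2 * count * n_b / total
        count = total
    return mean.to(x_clast.dtype), (m2 / count).to(x_clast.dtype)


torch.manual_seed(0)
# Same shape as the failing case: 65 channels, 100 x 100 spatial size.
x = torch.rand(1, 65, 100, 100).permute(0, 2, 3, 1).contiguous()
var_ref, mean_ref = torch.var_mean(x, dim=[0, 1, 2], unbiased=False)
mean_w, var_w = welford_reference(x)
print(torch.allclose(mean_ref, mean_w, atol=1e-6))
print(torch.allclose(var_ref, var_w, atol=1e-6))
```

If both `torch.var_mean` and this reference agree with each other but disagree with `mean_apex` and `var_apex`, the discrepancy is on the `syncbn` side.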
**Expected Behavior**
`syncbn.welford_mean_var_c_last` should return the same mean and var as `torch.var_mean` regardless of the number of features or the spatial size, so `torch.allclose(mean, mean_apex)` should be `True` in both cases above.
Environment
ngc_23.11