NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

syncbn with "channel_last=True" produces wrong results when feature_num is not a power of two #1768

Open Zehaos opened 8 months ago

Zehaos commented 8 months ago

**Describe the Bug**

When feature_num is not a power of two, `apex.parallel.SyncBatchNorm` (with `channel_last=True`) produces wrong results.
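For reference, the repro compares `welford_mean_var_c_last` against `torch.var_mean(..., unbiased=False)`, so the statistic it is expected to return is the per-channel mean and biased variance over every non-channel position. The sketch below computes that with Welford's single-pass update in plain Python, assuming a flattened channels-last buffer in which element `p * C + c` belongs to pixel `p` and channel `c`. `reference_welford_c_last` is a hypothetical helper, not part of apex. Since the channel index is simply `i % C`, nothing in the math itself depends on `C` being a power of two.

```python
from typing import List, Sequence, Tuple


def reference_welford_c_last(data: Sequence[float], num_channels: int) -> Tuple[List[float], List[float]]:
    """Hypothetical reference: per-channel mean and biased variance of a
    flattened channels-last (NHWC) buffer, using Welford's online update."""
    if num_channels <= 0 or len(data) % num_channels != 0:
        raise ValueError("buffer length must be a positive multiple of num_channels")
    count = [0] * num_channels
    mean = [0.0] * num_channels
    m2 = [0.0] * num_channels  # running sum of squared deviations per channel
    for i, x in enumerate(data):
        c = i % num_channels  # channels are the fastest-varying dimension
        count[c] += 1
        delta = x - mean[c]
        mean[c] += delta / count[c]
        m2[c] += delta * (x - mean[c])
    # Biased variance (divide by N), matching torch.var_mean(..., unbiased=False).
    var = [m2[c] / count[c] if count[c] else 0.0 for c in range(num_channels)]
    return mean, var


if __name__ == "__main__":
    C = 65  # same non-power-of-two channel count as the repro
    pixels = 100 * 100  # N * H * W for a 1 x 100 x 100 input
    data = [((p * 31 + c * 7) % 97) / 97.0 for p in range(pixels) for c in range(C)]
    mean, var = reference_welford_c_last(data, C)
    print(len(mean), len(var))  # 65 65
```

Comparing a kernel's output channel by channel against a reference like this on the failing 65 × 100 × 100 case can help show which channels diverge.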

I tested it step by step and found that it produces the wrong mean and variance when feature_h and feature_w are large enough (see the minimal reproduction code below).
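The issue does not identify the root cause. One hypothesis consistent with the report (correct at 10 × 10, wrong at 100 × 100) is that larger inputs are reduced in more partial pieces whose statistics must then be combined, so the partitioning or the combine step is worth checking; this is an assumption, not something the report establishes. The sketch below shows the standard pairwise merge of Welford states (Chan et al.) applied to contiguous pixel chunks of a channels-last buffer. `WelfordState`, `merge`, and `chunked_channel_stats` are hypothetical names, and the chunking only imitates, rather than reproduces, how a CUDA kernel might split the work. With a correct merge, the result is the same for any chunk count and any channel count.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class WelfordState:
    count: int = 0
    mean: float = 0.0
    m2: float = 0.0  # sum of squared deviations from the running mean

    def update(self, x: float) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)


def merge(a: WelfordState, b: WelfordState) -> WelfordState:
    """Combine two partial states with the pairwise (Chan et al.) update."""
    n = a.count + b.count
    if n == 0:
        return WelfordState()
    delta = b.mean - a.mean
    mean = a.mean + delta * b.count / n
    m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / n
    return WelfordState(n, mean, m2)


def chunked_channel_stats(data: Sequence[float], num_channels: int,
                          num_chunks: int) -> Tuple[List[float], List[float]]:
    """Split the pixels of a channels-last buffer into num_chunks contiguous
    pieces, reduce each piece independently, then merge per channel."""
    num_pixels = len(data) // num_channels
    bounds = [num_pixels * k // num_chunks for k in range(num_chunks + 1)]
    result = [WelfordState() for _ in range(num_channels)]
    for k in range(num_chunks):
        partial = [WelfordState() for _ in range(num_channels)]
        for p in range(bounds[k], bounds[k + 1]):  # may be empty if chunks > pixels
            for c in range(num_channels):
                partial[c].update(data[p * num_channels + c])
        result = [merge(r, s) for r, s in zip(result, partial)]
    means = [s.mean for s in result]
    variances = [s.m2 / s.count for s in result]  # biased variance
    return means, variances


if __name__ == "__main__":
    C = 65
    data = [((p * 31 + c * 7) % 97) / 97.0 for p in range(1000) for c in range(C)]
    one = chunked_channel_stats(data, C, num_chunks=1)
    many = chunked_channel_stats(data, C, num_chunks=37)
    print(all(math.isclose(a, b, rel_tol=1e-9)
              for a, b in zip(one[0] + one[1], many[0] + many[1])))  # True
```

Because every chunk count should give the same answer, sweeping the chunk count (or, on the GPU side, the spatial size) is a cheap way to look for the point where results start to diverge.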

**Minimal Steps/Code to Reproduce the Bug**

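## When feature_h and feature_w are small, the result is correct

```python
import torch
import syncbn  # apex's compiled SyncBatchNorm CUDA extension

feature_size = 65  # number of channels; not a power of two
feature_h = 10
feature_w = 10

# With small feature_h and feature_w, the mean and var are correct.
# The syncbn kernels operate on CUDA tensors, so the input is created on the GPU.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NCHW -> NHWC (channels last)

# Reference: per-channel statistics over N, H, W (biased variance).
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # True
```

## When feature_h and feature_w are large, the result is wrong

```python
import torch
import syncbn  # apex's compiled SyncBatchNorm CUDA extension

feature_size = 65  # number of channels; not a power of two
feature_h = 100
feature_w = 100

# With large feature_h and feature_w, the mean and var are wrong.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NCHW -> NHWC (channels last)

var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # False
```

**Expected Behavior**

`syncbn.welford_mean_var_c_last` should return the same per-channel mean and biased variance as `torch.var_mean(..., unbiased=False)`, regardless of the channel count or the spatial size.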

**Environment**

ngc_23.11

Zehaos commented 8 months ago

cc @jjsjann123