pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Revisiting SyncBatchNorm #2843

Closed evanatyourservice closed 3 years ago

evanatyourservice commented 3 years ago

🚀 Feature

Hello all,

I am wondering if there is any interest in supporting Pytorch's SyncBatchNorm, also mentioned in issue #2223. This is a commonly used tool nowadays.

Alternatives

This is the alternative I just came up with:

import torch.nn as nn
import torch_xla.core.xla_model as xm

def reduce_bn_stats(net):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm1d):
            # Passing a list makes xm.all_reduce update the buffers in place.
            xm.all_reduce(xm.REDUCE_SUM, [m.running_mean, m.running_var],
                          scale=1.0 / xm.xrt_world_size())

but can anyone think of a more efficient way to do this for now?
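As an aside, one way to cut the number of collective calls would be to gather every BatchNorm buffer and issue a single in-place all_reduce for the whole model. A minimal sketch, assuming the same xm API and PyTorch's private _BatchNorm base class (the helper name is just for illustration):

import torch.nn as nn
import torch_xla.core.xla_model as xm

def reduce_bn_stats_batched(net):
    # Collect the running stats of every BN layer (1d/2d/3d) in the model.
    buffers = []
    for m in net.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm) and m.track_running_stats:
            buffers += [m.running_mean, m.running_var]
    if buffers:
        # A single list-input all_reduce updates all buffers in place.
        xm.all_reduce(xm.REDUCE_SUM, buffers, scale=1.0 / xm.xrt_world_size())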

JackCaoG commented 3 years ago

@evanatyourservice thanks for reporting! I think we could lower this op, but under the hood we would use the existing cross_replica_reduce methods in https://github.com/pytorch/xla/blob/08ae1044c2a7e314895f9946104cbe399e096515/torch_xla/csrc/cross_replica_reduces.cpp to implement it. @jysohn23 wdyt?

jysohn23 commented 3 years ago

Yeah I'm pretty sure that's what tf2xla bridge also does for SyncBN.

evanatyourservice commented 3 years ago

Awesome! I'm not that well versed, so I don't think I can help, but I was just curious.

rwightman commented 3 years ago

Definitely of value for some nets in the large-image-size categories, where batch sizes need to be small and stability is problematic without syncbn; I'm thinking of many object detection/segmentation/etc. situations.

@evanatyourservice what you're doing there actually works quite well to keep running stats in sync between distributed replicas when not using syncbn. It's cheap because it's done once per epoch, but it doesn't help with stability concerns. It does prevent your replicas' running stats from drifting, though. The variance isn't technically correct, but it's good enough.

Syncbn runs every step, operates on the per-batch mean and squared-mean statistics, and calculates the variance properly.
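To make the variance point concrete: averaging per-replica variances is not the variance of the combined batch when the per-replica means differ. The usual fix is to reduce the sum and sum of squares and use var = E[x^2] - E[x]^2. A minimal sketch with plain tensors standing in for two replicas:

import torch

# Two "replicas" whose per-device batches have different means.
xs = [torch.randn(8, 4), torch.randn(8, 4) + 3.0]

# Naive: mean of per-replica variances (roughly what averaging running_var does).
naive_var = torch.stack([x.var(dim=0, unbiased=False) for x in xs]).mean(dim=0)

# Proper: combine sums and sums of squares, then var = E[x^2] - E[x]^2.
n = sum(x.shape[0] for x in xs)
mean = sum(x.sum(dim=0) for x in xs) / n
sqr_mean = sum((x * x).sum(dim=0) for x in xs) / n
global_var = sqr_mean - mean * mean

print(torch.allclose(global_var, torch.cat(xs).var(dim=0, unbiased=False)))  # True
print(torch.allclose(naive_var, global_var))  # False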

evanatyourservice commented 3 years ago

@rwightman Oh I see; it seems I don't know enough about how the variance is calculated to do this properly, so I'll have to look into it. I've been running that after every backward pass and it doesn't seem to add much overhead at all.

rwightman commented 3 years ago

@evanatyourservice yeah, the speed of your approach isn't an issue.

Here are two reasonable refs that are fairly clear: a TF TPU-oriented impl and the NVIDIA Apex impl for PyTorch. The PyTorch native one isn't as easy to read and is split across multiple files...

https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py#L118 (this overrides the base BatchNorm, so you need to look at the base class for the full impl)

https://github.com/NVIDIA/apex/blob/master/apex/parallel/sync_batchnorm.py#L68
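For orientation, a rough sketch of how the per-step statistics sync in those references could map onto xm.all_reduce: each replica contributes its per-channel sum, sum of squares and element count, and one collective yields the global batch moments. This only illustrates the statistics math; a full SyncBN layer (as in the Apex reference) also routes gradients back through the collective in its backward pass, which is omitted here. The function name and NCHW layout are assumptions for the example:

import torch
import torch_xla.core.xla_model as xm

def global_batch_moments(x):
    # x: (N, C, H, W) activations on one replica.
    c = x.shape[1]
    count = torch.tensor([x.numel() / c], device=x.device, dtype=x.dtype)
    # Pack per-channel sum, per-channel sum of squares and the element count
    # into one tensor so a single collective covers all of them.
    packed = torch.cat([x.sum(dim=[0, 2, 3]), (x * x).sum(dim=[0, 2, 3]), count])
    # The single-tensor form of all_reduce returns the reduced tensor.
    packed = xm.all_reduce(xm.REDUCE_SUM, packed)
    n = packed[-1]
    mean = packed[:c] / n
    var = packed[c:2 * c] / n - mean * mean  # E[x^2] - E[x]^2
    return mean, var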

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.