Status: Closed (closed by @evanatyourservice 3 years ago)
@evanatyourservice thanks for reporting! I think we could lower this op, but under the hood we would end up using the existing cross_replica_reduce methods in https://github.com/pytorch/xla/blob/08ae1044c2a7e314895f9946104cbe399e096515/torch_xla/csrc/cross_replica_reduces.cpp to implement it. @jysohn23 wdyt?
Yeah I'm pretty sure that's what tf2xla bridge also does for SyncBN.
Awesome, I'm not that well versed in the internals so I don't think I can help, but I was just curious!
Definitely of value for some nets in the large-image-size category, where batch sizes need to be small and stability is problematic without syncbn; I'm thinking of many situations in object detection/segmentation/etc.
@evanatyourservice what you're doing there actually works quite well to keep running stats in sync between distributed replicas when not using syncbn. It's cheap because it's done once per epoch, but it doesn't help stability concerns. It does prevent your replicas' running stats from drifting, though. The variance isn't technically correct, but good enough.
SyncBN runs every step; it operates on the mean and squared-mean batch statistics and calculates the variance properly.
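To make the distinction concrete, here's a small framework-free simulation (plain Python, no PyTorch/XLA; the two "replicas" are just lists). It shows that deriving variance from the aggregated mean and mean-of-squares matches the true global variance, while averaging each replica's local variance does not:

```python
from statistics import fmean, pvariance

# Two replicas each hold a small per-device batch (equal sizes assumed).
replica_batches = [[1.0, 2.0, 3.0], [10.0, 11.0, 12.0]]
all_values = [x for b in replica_batches for x in b]

# SyncBN-style: "all-reduce" the mean and the mean of squares,
# then derive the variance from the aggregated moments.
global_mean = fmean(fmean(b) for b in replica_batches)
global_sqr_mean = fmean(fmean(x * x for x in b) for b in replica_batches)
syncbn_var = global_sqr_mean - global_mean ** 2

# Naively averaging each replica's local variance is NOT the same:
# it misses the between-replica spread of the means.
naive_var = fmean(pvariance(b) for b in replica_batches)

print(syncbn_var, pvariance(all_values), naive_var)
```

Here `syncbn_var` agrees with the population variance of the concatenated batch, while `naive_var` is far smaller, which is the "not technically correct" issue mentioned above.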
@rwightman Oh I see, it seems I don't know enough about how the variance is calculated to do this properly; I'll have to look into it. I've been running that after every backward pass and it doesn't seem to add much overhead at all.
@evanatyourservice yeah, speed of your approach isn't an issue
Here are two reasonable refs that are fairly clear: a TF TPU-oriented impl and the PyTorch APEX impl. The PyTorch native one isn't as easy to read and is split across multiple files.
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py#L118 (this overrides the base BatchNorm so need to look at base class for full impl)
https://github.com/NVIDIA/apex/blob/master/apex/parallel/sync_batchnorm.py#L68
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature
Hello all,
I am wondering if there is any interest in supporting PyTorch's SyncBatchNorm, also mentioned in issue #2223. This is a commonly used tool nowadays.
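For context, in native (non-XLA) distributed PyTorch this is exposed as `torch.nn.SyncBatchNorm`, and an existing model's BatchNorm layers can be swapped in place with `convert_sync_batchnorm`. A minimal sketch (assuming a standard `torch` install; actually running the converted model would additionally require an initialized process group):

```python
import torch
from torch import nn

# A toy model containing a regular BatchNorm2d layer.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Recursively replace every _BatchNorm module with SyncBatchNorm.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

print(type(sync_model[1]).__name__)  # SyncBatchNorm
```

The feature request here is essentially for the XLA backend to support this same module (lowering its cross-replica reduction onto TPU collectives).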
Alternatives
This is the alternative I just came up with:
but can anyone think of a more efficient way to do this for now?
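The interim workaround discussed in this thread, periodically averaging each BatchNorm's running stats across replicas, can be sketched as follows. This is a hedged, framework-free simulation: real torch_xla code would instead all-reduce the `running_mean` / `running_var` buffers across devices, and `sync_running_stats` plus the dict-based "replicas" here are hypothetical names for illustration.

```python
def sync_running_stats(replica_stats):
    """Average running_mean and running_var across replicas, in place.

    Note: averaging per-replica variances is only an approximation
    (it ignores the between-replica variance of the means), but as
    noted in the thread it is enough to stop the stats from drifting.
    """
    n = len(replica_stats)
    dim = len(replica_stats[0]["running_mean"])
    mean = [sum(s["running_mean"][i] for s in replica_stats) / n for i in range(dim)]
    var = [sum(s["running_var"][i] for s in replica_stats) / n for i in range(dim)]
    for s in replica_stats:
        s["running_mean"], s["running_var"] = list(mean), list(var)
    return replica_stats

# Two simulated replicas whose stats have drifted apart.
replicas = [
    {"running_mean": [0.0, 1.0], "running_var": [1.0, 2.0]},
    {"running_mean": [2.0, 3.0], "running_var": [3.0, 4.0]},
]
sync_running_stats(replicas)
print(replicas[0])  # both replicas now hold the averaged stats
```

Running this once per epoch (or even per step, since it's cheap) keeps replicas consistent, but unlike true SyncBN it does nothing for the per-step batch statistics used during training.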