That sounds like something we definitely should support. I think we can add it to the build_
The only concern I have is that I'm not sure what impact this will have if we load the imagenet weights. I've been freezing the batch norm layers in that case, so maybe it doesn't matter here, as I'm assuming we're training from scratch in the distributed case.
WDYT?
I'm not sure how easy it is to search and replace a layer in tf. It can be done, but it may be tricky; naively iterating over the layers like this breaks with skip connections:
x = model.input
for layer in model.layers[1:]:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        x = tf.keras.layers.experimental.SyncBatchNormalization()(x)
    else:
        x = layer(x)
I'll give it a try, but if it doesn't work I might need to add the code for ResNet50 like we did for ResNet18, instead of taking it from keras.applications.
Regarding the imagenet weights, this is indeed an issue. Loading all the weights and then re-initializing the BN layers is a bad idea; it should be explicit that you cannot use synced BN and load imagenet weights. Maybe it is possible to load the weights from the regular BN into the synced version; that would be the best option, and I'll look into it too.
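If it helps, here is a rough sketch (untested, just to illustrate the direction) of doing the replacement with tf.keras.models.clone_model and a clone_function. Because clone_model rebuilds the functional graph, the skip connections are handled for us, and since BatchNormalization and SyncBatchNormalization store the same variables, copying the weights afterwards should also carry the imagenet BN statistics into the synced layers:

```python
import tensorflow as tf


def convert_sync_batchnorm(model):
    """Sketch: clone `model`, swapping BatchNormalization for SyncBatchNormalization."""

    def clone_fn(layer):
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Both layers take the same config keys (axis, momentum, epsilon, ...).
            return tf.keras.layers.experimental.SyncBatchNormalization.from_config(
                layer.get_config()
            )
        # Mirror clone_model's default behaviour for every other layer.
        return layer.__class__.from_config(layer.get_config())

    synced = tf.keras.models.clone_model(model, clone_function=clone_fn)
    # BN and SyncBN hold the same variables (gamma, beta, moving mean/variance),
    # so the pretrained statistics can be copied straight across.
    synced.set_weights(model.get_weights())
    return synced
```

If the two layers' config keys ever diverge this would need some filtering, but it would let us keep using keras.applications instead of re-implementing ResNet50.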
Also, can you re-assign me? I deleted myself from the issue by mistake :-)
I added you back as the issue assignee.
I agree, the search and replace might be tricky. On the other hand, it would be good to avoid duplicating the tf.keras.applications models if possible. We didn't have a choice for ResNet18, but it would be good to use the applications models where we can.
I also like your approach regarding the imagenet weights. I think we can just make it so that synced BN is only trainable from scratch.
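One simple way to make that explicit (hypothetical builder name and arguments, just to show the intent):

```python
def build_resnet50(weights=None, use_sync_bn=False):
    """Hypothetical builder: synced BN is only allowed when training from scratch."""
    if use_sync_bn and weights is not None:
        raise ValueError(
            "SyncBatchNormalization is only supported when training from scratch; "
            "pass weights=None."
        )
    ...  # build the model, using SyncBatchNormalization when use_sync_bn is True
```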
When using distributed strategies (a must for these models), SimSiam, SimCLR, and Barlow Twins all use synced batch norm across devices. SimCLR uses tf.keras.layers.experimental.SyncBatchNormalization, while SimSiam and Barlow Twins use PyTorch's nn.SyncBatchNorm.convert_sync_batchnorm. We should either rewrite the models with SyncBatchNormalization (which means implementing ResNet50 ourselves) or come up with a tf function, convert_sync_batchnorm, that replaces the BN layers. What do you think? (I would like to work on this issue.)
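For reference, minimal usage of the two APIs mentioned above (the model and input shape here are placeholders):

```python
import tensorflow as tf
import torch
import torchvision

# TensorFlow (what SimCLR does): build the model with SyncBatchNormalization
# inside the distribution strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = tf.keras.layers.Conv2D(64, 3)(inputs)
    x = tf.keras.layers.experimental.SyncBatchNormalization()(x)
    model = tf.keras.Model(inputs, x)

# PyTorch (what SimSiam / Barlow Twins do): convert an existing model's BN layers.
resnet = torchvision.models.resnet50()
resnet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(resnet)
```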