Closed Re3write closed 5 years ago
def convert_model(module):
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=[0, 1, 2, 3]).cuda()
        return mod
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                ... = ...
                ... = ...
            return mod
    for name, child in module.named_children():
        module.add_module(name, convert_model(child))
    return module
This is our version, which works around the problem.
Thanks for reporting!
I just tested the current version myself:
from torchvision import models
from sync_batchnorm import convert_model
m = models.resnet18(pretrained=True)
m = convert_model(m)
The code above runs successfully and gives the expected output network. Could you please specify the case where our current version fails? That would be deeply appreciated. Thanks!
Traceback (most recent call last):
File "", line 7, in
@Re3write Can you make sure that you have this line in your file?
It looks to me that you somehow deleted this line?
@vacancy Sorry, the code we use doesn't have that line; maybe we accidentally deleted it.
No worries. Best of luck!
There is a problem when I run the example code for convert_model: the variable 'mod' is not assigned. It seems something is wrong with the recursion.
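For anyone hitting the same error, here is a minimal sketch of the failure mode that needs no PyTorch. The `Node` class and both converter functions are hypothetical stand-ins, not the library's code, and the fix shown (a default assignment `mod = node` before the type check) is an assumption about what the deleted line did, since the thread does not quote it.

```python
class Node:
    """Toy stand-in for a module with named children (hypothetical, not torch.nn.Module)."""

    def __init__(self, kind, **children):
        self.kind = kind
        self.children = dict(children)


def convert_broken(node):
    # `mod` is only bound when the node is a "bn" leaf, so every other
    # node reaches a use of `mod` while it is still unbound.
    if node.kind == "bn":
        mod = Node("sync_bn")
    for name, child in list(node.children.items()):
        mod.children[name] = convert_broken(child)
    return mod


def convert_fixed(node):
    # Assumed fix: default to the node itself, replace it only when needed.
    mod = node
    if node.kind == "bn":
        mod = Node("sync_bn")
    for name, child in list(node.children.items()):
        mod.children[name] = convert_fixed(child)
    return mod
```

Calling `convert_broken` on any tree that contains a non-"bn" node raises `UnboundLocalError`, which matches the "variable 'mod' was not assigned" symptom, while `convert_fixed` returns the same tree with only the "bn" leaves swapped out.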