vballoli / nfnets-pytorch

NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch. Find explanation at tourdeml.github.io/blog/
https://nfnets-pytorch.readthedocs.io/en/latest/
MIT License
343 stars, 29 forks

TypeError: replace_conv() takes 1 positional argument but 2 were given #2

Closed: sainatarajan closed this issue 3 years ago

sainatarajan commented 3 years ago

Describe the bug

Hello, thank you for the repo. I'm having an issue with replace_conv(): calling it on a torchvision ResNet raises the TypeError in the title.

To Reproduce

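import torch
import torchvision
from torch import nn
from nfnets import replace_conv  # assumed import path; not shown in the report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed; not shown in the report
model = torchvision.models.resnet101(pretrained=False)
model.fc = nn.Linear(in_features=2048, out_features=8, bias=True)
replace_conv(model)  # raises the TypeError in the issue title
model = model.to(device)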

Screenshots

Here is a screenshot of the stack trace. [image]

vballoli commented 3 years ago

Oh yeah, sorry about that; that must have been a last-minute mistake. Simply replacing it with replace_conv(ch) should work. A fix should be pushed soon. Note that the BatchNorm layers aren't being removed as of now, so replace_conv isn't fully functional yet.
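
One plausible reading of the traceback, assuming the extra argument came from a recursive call inside replace_conv rather than from the caller: a function declared with one parameter that recurses with two fails on its first child, even though the user passed a single argument. The sketch below reproduces that pattern on a hypothetical toy Node tree (a stand-in for nested modules, not the library's code); replace_buggy, replace_fixed, make_tree and kinds are illustrative names only.

```python
class Node:
    """Hypothetical stand-in for a module that holds child modules."""

    def __init__(self, kind, children=()):
        self.kind = kind
        self.children = list(children)


def replace_buggy(node):
    # The signature accepts one argument...
    for ch in node.children:
        if ch.kind == "conv":
            ch.kind = "ws_conv"
        # ...but the recursive call passes two, so it raises TypeError.
        replace_buggy(ch, "ws_conv")


def replace_fixed(node):
    for ch in node.children:
        if ch.kind == "conv":
            ch.kind = "ws_conv"
        replace_fixed(ch)  # recursive call matches the one-argument signature


def make_tree():
    return Node("net", [Node("conv"), Node("block", [Node("conv"), Node("bn")])])


def kinds(node):
    # Flatten the tree into a pre-order list of node kinds.
    return [node.kind] + [k for ch in node.children for k in kinds(ch)]


try:
    replace_buggy(make_tree())
    error_message = None
except TypeError as exc:
    error_message = str(exc)

tree = make_tree()
replace_fixed(tree)
print(error_message)
print(kinds(tree))
```

With the single-argument recursive call, every nested "conv" node is rewritten and no exception is raised.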

vballoli commented 3 years ago

This commit https://github.com/vballoli/nfnets-pytorch/commit/24d9398d30f0c5f04838667cfc176cab1a5f136f fixes the issue and replaces the BatchNorm2d layers with identity. Feel free to re-open if you find more issues.
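
For readers who want to see what swapping BatchNorm2d for identity can look like, here is a minimal sketch using only standard PyTorch calls (named_children, setattr, nn.Identity). It is not the code from the commit; strip_batchnorm is a hypothetical helper name, and the small nn.Sequential network is made up for illustration.

```python
import torch
from torch import nn


def strip_batchnorm(module: nn.Module) -> nn.Module:
    """Recursively swap every nn.BatchNorm2d child for nn.Identity, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # Re-assigning an existing child name replaces it in the parent.
            setattr(module, name, nn.Identity())
        else:
            strip_batchnorm(child)
    return module


# Hypothetical toy network with one top-level and one nested BatchNorm2d.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=1), nn.BatchNorm2d(8)),
)
strip_batchnorm(net)
out = net(torch.randn(2, 3, 16, 16))
print(net)
```

Because nn.Identity returns its input unchanged, the tensor shapes flowing through the network stay the same after the swap.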