bfialkoff closed this issue 3 years ago.
This sounds really good! The CUDA checks completely slipped my mind, and since I'm working on a Mac, I did not bother testing with CUDA. I'll replace that line with your solution; it seems perfect. Thanks a ton!
Happy to contribute!
(Sent by email in reply to Vaibhav Balloli's comment of Mon, Feb 15, 2021, 16:16.)
Perfect! Thanks for your contribution; I'm also facing this issue now.
Fixed in https://github.com/vballoli/nfnets-pytorch/commit/55f9e2454c173314b7e3faab3934ea06d329e2fd. Thanks again, @bfialkoff!
Describe the bug
Calling model.forward() fails when the model runs on a CUDA device.
To Reproduce
Steps to reproduce the behavior:
1. Run the model on torch.device('cuda').
2. Call model.forward().
Expected behavior
The forward pass completes and returns the regular output.
Screenshots
See here
My Solution:
It's obviously hacky, but it works. Simply replace line 22 of nfnets/base.py (https://github.com/vballoli/nfnets-pytorch/blob/867860eebffcc70fb87a389d770cfd4a73c6b30c/nfnets/base.py#L22) with:
```python
# Move the eps tensor and the gain onto the same device as var before combining them
scale = torch.rsqrt(torch.max(var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
```
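For comparison, here is a minimal sketch of a scaled weight-standardized convolution whose scale computation never creates a tensor on a fixed device. It assumes a layer similar in spirit to the one in nfnets/base.py, but the class name DeviceSafeWSConv2d, the eps default of 1e-4, the per-channel gain shape, and the use of torch.clamp with a Python float in place of torch.max against an eps tensor are all illustrative choices of this sketch, not the library's code or the fix in the linked commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeviceSafeWSConv2d(nn.Conv2d):
    """Hypothetical scaled weight-standardized Conv2d (not the nfnets-pytorch class).

    Everything in the scale computation comes from self.weight, a registered
    parameter, or a plain Python number, so nothing is pinned to the CPU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, eps=1e-4, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # A registered nn.Parameter is moved by module.to(device) along with the weight.
        self.gain = nn.Parameter(torch.ones(out_channels, 1, 1, 1))
        self.eps = eps  # kept as a Python float on purpose

    def standardized_weight(self):
        w = self.weight  # shape: (out_channels, in_channels // groups, kh, kw)
        fan_in = w[0].numel()  # Python int: (in_channels // groups) * kh * kw
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        # Population variance per output channel, computed on w's device.
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
        # clamp takes eps as a Python float, so no eps tensor is created anywhere.
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        return (w - mean) * scale

    def forward(self, x):
        # This sketch always zero-pads; padding_mode values other than 'zeros' are ignored.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


if __name__ == "__main__":
    # Runs on CUDA when available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    conv = DeviceSafeWSConv2d(3, 8, 3, padding=1).to(device)
    out = conv(torch.randn(2, 3, 16, 16, device=device))
    print(out.shape, out.device)
```

Because eps stays a Python float, fan_in is a plain int, and gain is a registered parameter, a single .to(device) call on the module moves everything the forward pass needs, so there is no per-call .to(var.device) to remember.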