vballoli / nfnets-pytorch

NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch. Find an explanation at tourdeml.github.io/blog/
https://nfnets-pytorch.readthedocs.io/en/latest/
MIT License

torch.max doesn't check for tensors being on different devices. #3

Closed: bfialkoff closed this issue 3 years ago.

bfialkoff commented 3 years ago


To reproduce:

  1. Go to the example and instantiate a resnet18.
  2. Send the model to torch.device('cuda').
  3. Define an input tensor on the GPU.
  4. Call model.forward() on it.
  5. The call fails with: RuntimeError: iter.device(arg).is_cuda() INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Loops.cuh":94, please report a bug to PyTorch.

Expected behavior: regular output.

Screenshots: See here

My Solution:

It's hacky, obviously, but it works. Simply replace https://github.com/vballoli/nfnets-pytorch/blob/867860eebffcc70fb87a389d770cfd4a73c6b30c/nfnets/base.py#L22 with:

    scale = torch.rsqrt(torch.max(var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
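The crash presumably comes from torch.tensor(eps) being allocated on the default device (the CPU) while var lives on the GPU, so the two-tensor torch.max call mixes devices; the workaround above moves eps onto var.device. Below is a minimal sketch of a device-agnostic alternative, assuming an NFNets-style scaled weight standardization, gain * (W - mean) / sqrt(fan_in * var), with the product floored at eps. WSConv2dSketch and standardized_weight are hypothetical names, not the nfnets-pytorch API, and the sketch only supports zero padding. Flooring with Tensor.clamp(min=eps) takes a plain Python float, so there is no second tensor whose device can disagree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2dSketch(nn.Conv2d):
    """Hypothetical Conv2d with NFNets-style scaled weight standardization.

    A sketch only, not the nfnets-pytorch implementation. Zero padding only:
    padding_mode is ignored because forward calls F.conv2d directly.
    """

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # One learnable gain per output channel, shaped to broadcast over the weight.
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps  # plain Python float: it has no device, so it cannot mismatch

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight
        fan_in = w[0].numel()  # (in_channels / groups) * kernel_h * kernel_w
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)  # population variance
        # clamp(min=float) replaces torch.max(x, torch.tensor(eps)), whose
        # second argument would otherwise be created on the CPU.
        scale = torch.rsqrt((var * fan_in).clamp(min=self.eps)) * self.gain
        return (w - mean) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


# Usage: runs unchanged on CPU or GPU because no tensor is created in forward.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conv = WSConv2dSketch(3, 8, kernel_size=3, padding=1).to(device)
x = torch.randn(2, 3, 16, 16, device=device)
print(conv(x).shape)  # torch.Size([2, 8, 16, 16])
```

Registering eps with register_buffer would also work, since Module.to() moves buffers along with parameters; clamping with a float simply avoids keeping an extra tensor around.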

vballoli commented 3 years ago

This sounds really good! The CUDA checks completely slipped my mind, and since I'm working on a Mac, I didn't bother testing with CUDA. I'll replace that line with your solution; it seems perfect. Thanks a ton!

bfialkoff commented 3 years ago

Happy to contribute!

(Sent by email in reply to Vaibhav Balloli's comment of Mon, Feb 15, 2021, 16:16.)

Hansxsourse commented 3 years ago

Perfect! Thanks for your contribution; I'm also facing this issue now.

vballoli commented 3 years ago

Fixed in https://github.com/vballoli/nfnets-pytorch/commit/55f9e2454c173314b7e3faab3934ea06d329e2fd . Thanks again, @bfialkoff!