pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.nn.functional.gumbel_softmax reporting NaNs in Pytorch 2.0 #103459

Closed · bryant0918 closed this issue 1 year ago

bryant0918 commented 1 year ago

🐛 Describe the bug

I am now seeing the same issue in PyTorch 2.0: gumbel_softmax returns NaN for the input below.

Code to reproduce:

import torch
from torch.nn.functional import gumbel_softmax

def test_gumbel():
    print(torch.__version__)
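    torch.manual_seed(1234)
    # About 377M float32 values (~1.5 GB), uniform in [-1, 1)
    randy = torch.rand((5, 8192, 96, 96)) * 2 - 1
    randy *= 255  # scale to [-255, 255)

    print(randy.dtype, randy.shape, torch.min(randy), torch.max(randy))
    out = gumbel_softmax(randy)  # defaults: tau=1, hard=False, dim=-1
    print(out.dtype, out.shape, torch.min(out), torch.max(out))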

if __name__ == "__main__":
    test_gumbel()

Output:

2.0.1+cu118
torch.float32 torch.Size([5, 8192, 96, 96]) tensor(-255.) tensor(255.0000)
torch.float32 torch.Size([5, 8192, 96, 96]) tensor(nan) tensor(nan)
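A minimal sketch of how one bad noise sample turns into NaN, assuming (as the thread later establishes) that exponential_() occasionally returns exactly 0. The tensors are hand-picked, hypothetical values rather than real draws; the last step mirrors the soft path of gumbel_softmax, which adds the noise to the logits, divides by tau, and applies softmax along dim.

```python
import torch

# Hypothetical values: the middle entry stands in for an exponential_()
# draw that came back as exactly 0.0.
logits = torch.tensor([1.0, 2.0, 3.0])
exp_samples = torch.tensor([0.5, 0.0, 1.5])

gumbels = -exp_samples.log()  # -log(0) = +inf for the middle entry
tau = 1.0
y = ((logits + gumbels) / tau).softmax(dim=-1)  # same soft path as gumbel_softmax

print(gumbels)
print(y)  # contains NaN: an inf input leads to inf - inf or inf / inf inside softmax
```

With finite noise, softmax of finite logits in [-255, 255] stays finite, so the large logit scale in the repro is not what produces the NaN; the infinite noise term is.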

Versions

Python 3.9.7, PyTorch 2.0.1+cu118

The collect_env.py script could not be downloaded, so no full environment report is available:

--2023-06-12 14:57:28-- https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
Resolving raw.githubusercontent.com... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com|185.199.108.133|:443... connected.
OpenSSL: error:140770FC:SSL routines:SSL23_GET_SERVER_HELLO:unknown protocol
Unable to establish SSL connection.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

bryant0918 commented 1 year ago

A simple workaround for now is to change the gumbel_softmax function in torch/nn/functional.py so that it uses its otherwise unused (deprecated) eps parameter:

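# Original lines (commented out):
# gumbels = (
#     -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
# )  # ~Gumbel(0,1)
# Add eps before the log so an exponential_() draw of exactly 0 yields
# -log(eps) (about 23.03 for the default eps=1e-10) instead of +inf.
gumbels = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_() + eps
gumbels = -gumbels.log()  # ~Gumbel(0,1), bounded above by -log(eps)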
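For reference, the ~Gumbel(0,1) comment relies on a standard identity, sketched here with E denoting one exponential_() draw and \varepsilon the eps parameter. It shows why a single draw of exactly 0 is fatal and why adding eps caps the noise.

```latex
\begin{aligned}
E &\sim \mathrm{Exp}(1), \qquad G = -\log E, \\
P(G \le g) &= P\left(E \ge e^{-g}\right) = \exp\left(-e^{-g}\right) \quad \text{(the Gumbel(0,1) CDF)}, \\
G_\varepsilon &= -\log(E + \varepsilon) \le -\log \varepsilon = 10 \ln 10 \approx 23.03 \quad \text{for } \varepsilon = 10^{-10}.
\end{aligned}
```

In exact arithmetic P(E = 0) = 0, so G is finite almost surely; a floating-point draw of exactly 0 instead gives G = +inf, and the softmax then returns NaN. The eps version is only approximately Gumbel distributed, but the shift is negligible unless E is close to eps.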

However, exponential_() should never return 0 in the first place.
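If patching the installed torch/nn/functional.py is undesirable, the same eps guard can live in user code. Below is a minimal sketch of a hypothetical helper, gumbel_softmax_eps (not a PyTorch API), that follows the soft and hard (straight-through) paths of gumbel_softmax but adds eps before the log. It assumes float32 or float64 inputs; in float16, 1e-10 likely rounds to 0 and the guard would do nothing. It also uses the default memory format instead of torch.legacy_contiguous_format.

```python
import torch


def gumbel_softmax_eps(logits: torch.Tensor, tau: float = 1.0, hard: bool = False,
                       eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    """Hypothetical user-side variant of gumbel_softmax with the eps guard applied."""
    # Exponential(1) draws; adding eps keeps them strictly positive, so the
    # noise below is at most -log(eps) instead of +inf.
    exp_samples = torch.empty_like(logits).exponential_() + eps
    gumbels = -exp_samples.log()  # approximately Gumbel(0, 1) noise
    y_soft = ((logits + gumbels) / tau).softmax(dim)
    if not hard:
        return y_soft
    # Straight-through estimator: one-hot values forward, soft gradients backward.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


torch.manual_seed(1234)
# Same value range as the repro, but a much smaller shape.
logits = (torch.rand(4, 96) * 2 - 1) * 255
soft = gumbel_softmax_eps(logits)
hard = gumbel_softmax_eps(logits, hard=True)
print(torch.isnan(soft).any().item(), soft.sum(dim=-1))
```

The hard branch returns y_hard - y_soft.detach() + y_soft so that the forward value is one-hot while gradients flow through y_soft, which is the same straight-through trick gumbel_softmax uses for hard=True.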

mingfeima commented 1 year ago

@min-jean-cho, did your latest fix to exponential_ solve this issue?

min-jean-cho commented 1 year ago

Hi @bryant0918, @mingfeima, yes, the issue is resolved by https://github.com/pytorch/pytorch/pull/101720. This issue is also a duplicate of https://github.com/pytorch/pytorch/issues/101620.

albanD commented 1 year ago

Thanks! Should we close both of these issues now since they're resolved in main?

min-jean-cho commented 1 year ago

Thanks @albanD, yes, closing the issues as resolved by https://github.com/pytorch/pytorch/pull/101720.