pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Unstable buggy calculation. #68144

Open davidleon opened 3 years ago

davidleon commented 3 years ago
import torch
from torch.autograd import Variable
import torch.optim as optim
l=torch.Tensor([1.01, 1.01, 3, 3])

l=Variable(l.float(), requires_grad = True)
optimizer = optim.Adam([l], lr=0.01)
optimizer.zero_grad()

while True:
    l1 = l/(torch.cat([l.mean().unsqueeze(0),l]).cummax(0).values[1:])
    t = l1>=1
    t1 = l1 <1
    diffstate = l1[0:-1]==l1[1:]
    tsum=t.cumsum(0).float()
    substate = tsum - ((tsum*torch.cat([diffstate,torch.Tensor(1)])*t1).cummax(0)).values
    t1sum=t1.cumsum(0).float()
    substate1 = ((t1sum*torch.cat([diffstate,torch.Tensor(1)])*t).cummax(0)).values - t1sum 
    c=(l1-1)*(1+torch.tanh(substate/10+substate1/10))
    loss=-c.sum()
    loss.backward()
    optimizer.step()
    print(loss,l1)
>>> optimizer.step()
>>> print(loss,l1)
tensor(nan, grad_fn=<NegBackward>) tensor([0.5135, 0.5135, 1.0000, 1.0000], grad_fn=<DivBackward0>)
>>> substate
tensor([0., 0., 1., 2.])
>>> substate1
tensor([-1., -2.,  0., nan])
>>> substate1 = ((t1sum*torch.cat([diffstate,torch.Tensor(1)])*t).cummax(0)).values - t1sum
>>> substate1
tensor([-1., -2.,  0., nan])    <--- wrong
>>> ((t1sum*torch.cat([diffstate,torch.Tensor(1)])*t).cummax(0)).values
tensor([0., 0., 2., 2.])
>>> t1sum
tensor([1., 2., 2., 2.])
>>> ((t1sum*torch.cat([diffstate,torch.Tensor(1)])*t).cummax(0)).values - t1sum
tensor([-1., -2.,  0.,  0.])      <--- this result is correct 

The code above produces a nan loss after two iterations; in particular, substate1 evaluates to tensor([-1., -2., 0., nan]). However, if I recompute substate1 manually in the interpreter, it evaluates correctly to tensor([-1., -2., 0., 0.]). The last element should be 0, not nan.
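One possible explanation, not confirmed anywhere in this thread: torch.Tensor(1) does not build a tensor holding the value 1. It allocates a one-element float tensor whose contents are uninitialized, so the padding appended to diffstate can differ from one call to the next. The sketch below hard-codes the intermediate values printed above (t1sum, plus t and diffstate derived from the printed l1) and substitutes an explicit padding value. The helper name substate1_with_pad is hypothetical, and padding with 1 is only an assumption about what the code intended.

```python
import torch


def substate1_with_pad(pad):
    """Recompute substate1 from the report using an explicit padding tensor."""
    t1sum = torch.tensor([1., 2., 2., 2.])    # t1.cumsum(0), as printed in the report
    t = torch.tensor([0., 0., 1., 1.])        # l1 >= 1 for l1 = [0.5135, 0.5135, 1, 1]
    diffstate = torch.tensor([1., 0., 1.])    # l1[0:-1] == l1[1:] for the same l1
    mask = torch.cat([diffstate, pad])        # pad stands in for torch.Tensor(1)
    product = t1sum * mask * t
    return product, product.cummax(0).values - t1sum


# torch.Tensor(1) has shape (1,), but its single value is whatever was in memory.
print(torch.Tensor(1).shape)

# Assumed intent: pad with an explicit 1.
_, stable = substate1_with_pad(torch.ones(1))
print(stable)        # tensor([-1., -2.,  0.,  0.])

# If the uninitialized slot happens to hold nan, the last product entry is nan too.
bad_product, _ = substate1_with_pad(torch.tensor([float("nan")]))
print(bad_product)   # tensor([0., 0., 2., nan])
```

If this is the cause, replacing torch.Tensor(1) with torch.ones(1) or torch.zeros(1) in the original script should make the result reproducible; both paddings give the same substate1 for the values shown here.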

cc @vincentqb @jbschlosser @albanD

albanD commented 3 years ago

Running your code sample on Colab converges to 0 as expected. Could you share more details about your configuration and PyTorch version?
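The details requested here can be collected in a few lines. This is a minimal sketch; the dictionary keys are arbitrary labels chosen for illustration, not an official report format.

```python
import platform

import torch

# Gather the interpreter, library, and device details usually asked for in bug reports.
env = {
    "python": platform.python_version(),
    "torch": torch.__version__,
    "cuda_build": torch.version.cuda,              # None on CPU-only builds
    "cuda_available": torch.cuda.is_available(),
    "platform": platform.platform(),
}

for key, value in env.items():
    print(f"{key}: {value}")
```

PyTorch also ships a fuller report via python -m torch.utils.collect_env.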

davidleon commented 3 years ago

Python 3.8.6, torch 1.7.0