rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License
5.61k stars 930 forks source link

real is not implemented for tensors with non-complex dtypes #225

Open PengleiGao opened 1 year ago

PengleiGao commented 1 year ago

Hi, the dtype of the input is float32. there is no real of the input. the problem raises here "t = t.real.to(y.abs().dtype)" How to solve this problem? 1678256119272

DrKarlWu commented 1 year ago

这个的主要原因是因为t应该是是一个复数张量(最少二维),但是在神经常微分方程,t一般是个实数张量(一维),所以只需要把这个t变成二维复数张量即可,虚部当然是一个0了。

 if t.numel() == 1:
     mid = [t.item(),0.0]
     t = torch.tensor(mid)
 t = torch.view_as_complex(t)
 t = t.real.to(y.abs().dtype)
rtqichen commented 1 year ago

You should upgrade your PyTorch. The more recent versions (I think versions >= 1.6?) have .real implemented for non-complex tensor types. In the meantime, I'll work on a fix for the older PyTorch versions.