Open PengleiGao opened 1 year ago
这个的主要原因是因为t应该是是一个复数张量(最少二维),但是在神经常微分方程,t一般是个实数张量(一维),所以只需要把这个t变成二维复数张量即可,虚部当然是一个0了。
if t.numel() == 1:
mid = [t.item(),0.0]
t = torch.tensor(mid)
t = torch.view_as_complex(t)
t = t.real.to(y.abs().dtype)
You should upgrade your PyTorch. The more recent versions (I think versions >= 1.6?) have .real implemented for non-complex tensor types. In the meantime, I'll work on a fix for the older PyTorch versions.
Hi, the dtype of the input is float32. there is no real of the input. the problem raises here "t = t.real.to(y.abs().dtype)" How to solve this problem?