rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

odeint with torch.ComplexFloatTensor arguments #219

Closed. sriharikrishna closed this issue 1 year ago

sriharikrishna commented 1 year ago

It does not seem possible to use odeint with complex arguments. Is this a feature that will be supported in the future? Here is my MWE:

import torch
from torchdiffeq import odeint

t_i = 0.
t_f = 2.
t = torch.linspace(t_i, t_f, 10)

y0 = torch.tensor([1.0, 9.], dtype=torch.complex64)
A = torch.tensor([[0.0, 1.0], [-100.0, 0.0]], dtype=torch.complex64)

def ode_fn(t, y):
    return torch.mv(A, y)

true_y = odeint(ode_fn, y0, t, method='dopri5')

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-38fdf79c3601> in <module>
     12     return torch.mv(A, y)
     13 
---> 14 true_y = odeint(ode_fn, y0, t, method='dopri5')

/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/odeint.py in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     70     """
     71 
---> 72     shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
     73 
     74     solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)

/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/misc.py in _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
    211         if event_fn is not None:
    212             event_fn = _TupleInputOnlyFunc(event_fn, shapes)
--> 213     _assert_floating('y0', y0)
    214 
    215     # Normalise method and options

/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/misc.py in _assert_floating(name, t)
    104 def _assert_floating(name, t):
    105     if not torch.is_floating_point(t):
--> 106         raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type()))
    107 
    108 

TypeError: `y0` must be a floating point Tensor but is a torch.ComplexFloatTensor
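The traceback points at the cause: _assert_floating relies on torch.is_floating_point, which returns False for complex dtypes such as torch.complex64, so any complex y0 is rejected before the solver runs. A minimal sketch of a relaxed check that also admits complex tensors is below; check_dtype is a hypothetical name, and this is an illustration only, not the code from the torchdiffeq commit that later enabled complex types.

```python
import torch


def check_dtype(name, t):
    # Hypothetical relaxed check: accept real floating point OR complex tensors.
    # Illustration only; not the actual torchdiffeq implementation.
    if not (torch.is_floating_point(t) or torch.is_complex(t)):
        raise TypeError(
            "`{}` must be a floating point or complex Tensor but is a {}".format(name, t.type())
        )


y0 = torch.tensor([1.0, 9.0], dtype=torch.complex64)
print(torch.is_floating_point(y0))  # False: PyTorch does not count complex dtypes as floating point
print(torch.is_complex(y0))         # True
check_dtype("y0", y0)               # passes under the relaxed check
```

Integer tensors are still rejected by this sketch, which keeps the original intent of the check (the solver needs inexact arithmetic) while widening it to complex values.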
rtqichen commented 1 year ago

I've just enabled complex dtypes in the latest commit. Give it a try, and if there are any issues, let me know.

To install the latest commit:

pip install git+https://github.com/rtqichen/torchdiffeq
sriharikrishna commented 1 year ago

Works for me now. Thanks!
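Because the MWE is a linear system y' = A y, its exact solution is y(t) = exp(A t) y0, which gives a convenient reference for checking the solver output once complex inputs are accepted. A sketch, assuming a PyTorch version that provides torch.linalg.matrix_exp; exact_solution is a hypothetical helper name, not part of torchdiffeq.

```python
import torch

# Same system as the MWE: y' = A y with a complex64 state.
A = torch.tensor([[0.0, 1.0], [-100.0, 0.0]], dtype=torch.complex64)
y0 = torch.tensor([1.0, 9.0], dtype=torch.complex64)
t = torch.linspace(0.0, 2.0, 10)


def exact_solution(A, y0, t):
    # For a linear ODE y' = A y, the solution is y(t) = expm(A t) @ y0.
    # Broadcasting A (2, 2) against t (10, 1, 1) builds one matrix per time point.
    At = A * t.to(A.dtype)[:, None, None]
    # Batched matrix exponential (10, 2, 2) applied to y0 (2,) gives shape (10, 2).
    return torch.linalg.matrix_exp(At) @ y0


reference = exact_solution(A, y0, t)

# With a complex-enabled torchdiffeq installed, the solver output could be compared:
# from torchdiffeq import odeint
# true_y = odeint(lambda t, y: torch.mv(A, y), y0, t, method='dopri5')
# print((true_y - reference).abs().max())
```

For this particular A, the first component should follow cos(10 t) + 0.9 sin(10 t), since y1'' = -100 y1 with y1(0) = 1 and y1'(0) = 9.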

iranroman commented 1 year ago

I still see the same error with torchdiffeq 0.2.3.
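For a release that still rejects complex y0 (the comment above reports 0.2.3), one possible workaround is to integrate the real and imaginary parts as an ordinary real tensor. A minimal sketch, assuming torch.view_as_real and torch.view_as_complex are available in your PyTorch version; as_real_ode is a hypothetical helper, not part of torchdiffeq. One caveat: with this layout the adaptive step-size error norm is computed over real and imaginary components separately, which may differ slightly from how a complex-aware solver measures error.

```python
import torch


def as_real_ode(complex_fn):
    """Wrap a complex-valued right-hand side f(t, y) so it works on real tensors.

    The real state carries an extra trailing dimension of size 2 holding
    (real, imag), i.e. the layout produced by torch.view_as_real.
    """
    def real_fn(t, y_real):
        # view_as_complex needs a trailing dim of size 2 with stride 1;
        # .contiguous() guards against non-contiguous inputs from the solver.
        y = torch.view_as_complex(y_real.contiguous())
        dy = complex_fn(t, y)
        # Convert the complex derivative back to the (..., 2) real layout.
        return torch.view_as_real(dy)
    return real_fn


# The MWE's system, unchanged.
A = torch.tensor([[0.0, 1.0], [-100.0, 0.0]], dtype=torch.complex64)


def ode_fn(t, y):
    return torch.mv(A, y)


y0 = torch.tensor([1.0, 9.0], dtype=torch.complex64)
y0_real = torch.view_as_real(y0)  # shape (2, 2), dtype torch.float32

# With torchdiffeq installed, the call would look like:
# from torchdiffeq import odeint
# t = torch.linspace(0.0, 2.0, 10)
# sol_real = odeint(as_real_ode(ode_fn), y0_real, t, method='dopri5')
# sol = torch.view_as_complex(sol_real.contiguous())  # shape (10, 2), complex64
```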

dssd96 commented 1 year ago

For me, this issue is still present even after installing with pip install git+https://github.com/rtqichen/torchdiffeq

rtqichen commented 1 year ago

Do you have an example that can reproduce the error?

dssd96 commented 1 year ago

Never mind, this seems to have been a problem with pip: the installed version was never actually updated. I had to explicitly uninstall the package first and then install the git version.
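When pip appears to keep an old copy, it can help to confirm both what pip recorded and whether the imported odeint actually accepts complex y0. A hedged sketch: installed_version uses importlib.metadata (Python 3.8+), and accepts_complex_y0 is a hypothetical probe that treats a TypeError as "complex not supported". That matches the error in this issue but could also catch unrelated TypeErrors. A git install may report the same version number as a release, in which case the probe is more informative than the version string.

```python
import importlib.metadata

import torch


def installed_version(dist_name="torchdiffeq"):
    # Version string pip recorded for the distribution, or None if it is absent.
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def accepts_complex_y0(odeint_fn):
    # Hypothetical probe: releases without complex support raise a TypeError
    # from their input checks when y0 has a complex dtype.
    y0 = torch.tensor([1.0, 9.0], dtype=torch.complex64)
    t = torch.linspace(0.0, 0.1, 2)
    try:
        odeint_fn(lambda t, y: -y, y0, t)
    except TypeError:
        return False
    return True


# Usage, with torchdiffeq installed:
# from torchdiffeq import odeint
# print(installed_version(), accepts_complex_y0(odeint))
```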