StephenHogg opened this issue 2 years ago
I've concluded that this is likely because forward() gets wrapped to include a time argument. To avoid this, I gave NeuralODE a network that already takes the time argument, and hit this error instead:
RuntimeError:
Unknown type name 'Iterable':
File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/neuralde.py", line 92
def forward(self, x:Tensor, t_span:Tensor=None, save_at:Iterable=(), args={}):
~~~~~~~~ <--- HERE
x, t_span = self._prep_integration(x, t_span)
t_eval, sol = super().forward(x, t_span, save_at, args)
What's odd is that Iterable is definitely imported at the top of that file.
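For context, a network that "already has the time argument added" could look roughly like the sketch below. The class name, the (t, x) argument order and feeding t in as an extra input feature are assumptions for illustration, not torchdyn's convention; the point is only that such a module scripts on its own with torch.jit.script.

```python
import torch
from torch import Tensor, nn


class TimeDependentField(nn.Module):
    # Hypothetical vector field that takes t explicitly, so nothing has to
    # wrap its forward() to inject a time argument.
    def __init__(self, dim: int):
        super().__init__()
        # One extra input feature for the scalar time t.
        self.linear = nn.Linear(dim + 1, dim)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # Broadcast the 0-dim t to one column per batch row and append it to x.
        t_col = t.expand([x.shape[0], 1])
        return torch.tanh(self.linear(torch.cat([x, t_col], dim=1)))


field = torch.jit.script(TimeDependentField(dim=2))
out = field(torch.tensor(0.5), torch.zeros(4, 2))
print(out.shape)  # torch.Size([4, 2])
```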
After a bit more digging, I can see that TorchScript does not accept Iterable as an annotation. Here's the list of supported type hints; is it possible to change the annotation to one of these? Happy to write a PR if so.
https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
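A minimal sketch of the kind of change being asked about, assuming Optional[Tensor] (which is on the supported-types list linked above) is an acceptable replacement for Iterable. It also uses Optional[Tensor] for t_span, since the original signature annotates t_span as Tensor but defaults it to None. The class name and the fallback values are hypothetical, not torchdyn's.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class PrepIntegrationSketch(nn.Module):
    # Hypothetical module: only the forward() signature and defaults matter.
    def forward(
        self,
        x: Tensor,
        t_span: Optional[Tensor] = None,
        save_at: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        if t_span is None:
            # Assumed default for illustration: two time points on [0, 1].
            t_span = torch.linspace(0.0, 1.0, 2)
        if save_at is None:
            # Assumed default for illustration: save at every point of t_span.
            save_at = t_span
        # After the None checks, TorchScript refines both to plain Tensor.
        return x, t_span, save_at


scripted = torch.jit.script(PrepIntegrationSketch())
_, t_span, save_at = scripted(torch.zeros(3, 2))
```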
Another update: removing that type annotation revealed a second incorrect annotation, as well as a problem in the way forward is called. It seems this package would need a refactor to be safe to script with TorchScript. Let me know if you're interested in doing that; I'd be happy to work with you on it, since it would be great to be able to serialise these models.
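For reference, once a model scripts cleanly, serialising it is a matter of torch.jit.save and torch.jit.load. The sketch below uses a hypothetical toy module standing in for a scripted NeuralODE, and an in-memory buffer where a file path would normally go.

```python
import io

import torch
from torch import Tensor, nn


class Field(nn.Module):
    # Hypothetical stand-in for any module that torch.jit.script accepts.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x: Tensor) -> Tensor:
        return torch.tanh(self.linear(x))


model = torch.jit.script(Field())

# Serialise compiled code and weights into an in-memory buffer.
buffer = io.BytesIO()
torch.jit.save(model, buffer)

# Load it back; running it does not need the Python class definition.
buffer.seek(0)
restored = torch.jit.load(buffer)

x = torch.randn(4, 2)
print(torch.allclose(model(x), restored(x)))  # True
```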
Thanks for looking into this! I've started a refactor in this branch.
I fixed some of the typing inconsistencies and got as far as this error:
RuntimeError:
'Tensor' object has no attribute or method 'forward'.:
File "/home/stefano/michael/diffeqml/torchdyn/torchdyn/core/neuralde.py", line 94
def forward(self, x:Tensor, t_span:Tensor=None, save_at:Tensor=()):
x, t_span = self._prep_integration(x, t_span)
t_eval, sol = super().forward(x, t_span, save_at)
~~~~~~~~~~~~~ <--- HERE
if self.return_t_eval: return t_eval, sol
else: return sol
Is that the error you observed?
Hi @Zymrael, yes. That one is a problem because the usual fix there would be:
super(ODEProblem, self).forward(x, t_span, save_at)
But TorchScript does not accept types being passed as arguments, so that fix raises an error too. This implies that some of the inheritance the current code structure relies on would need to be reworked. That isn't necessarily a bad thing if it also reduces the amount of indirection in the code. Hope this helps; happy to keep discussing this one.
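Since the super(ODEProblem, self).forward(...) route does not script, one alternative is composition instead of inheritance: the user-facing module holds the solver as a submodule and calls it directly. The sketch below is hypothetical throughout: EulerProblem is a toy fixed-step Euler integrator standing in for ODEProblem, and ScriptableNeuralODE is not torchdyn's class. It also always returns the (t_eval, sol) pair, because TorchScript expects every return statement in a function to produce the same type, which a branch returning either a tuple or a single tensor (as in the traceback above) would likely violate.

```python
from typing import Tuple

import torch
from torch import Tensor, nn


class Decay(nn.Module):
    # Hypothetical vector field dx/dt = -x; t is accepted but unused.
    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return -x


class EulerProblem(nn.Module):
    # Hypothetical stand-in for a base "problem" class: fixed-step Euler.
    def __init__(self, vector_field: nn.Module):
        super().__init__()  # fine here: __init__ is not compiled by TorchScript
        self.vector_field = vector_field

    def forward(self, x: Tensor, t_span: Tensor) -> Tuple[Tensor, Tensor]:
        sol = [x]
        for i in range(t_span.shape[0] - 1):
            dt = t_span[i + 1] - t_span[i]
            x = x + dt * self.vector_field(t_span[i], x)
            sol.append(x)
        return t_span, torch.stack(sol)


class ScriptableNeuralODE(nn.Module):
    # Composition: holds the problem as a submodule instead of subclassing it,
    # so forward() never needs super().
    def __init__(self, vector_field: nn.Module):
        super().__init__()
        self.problem = EulerProblem(vector_field)

    def forward(self, x: Tensor, t_span: Tensor) -> Tuple[Tensor, Tensor]:
        t_eval, sol = self.problem(x, t_span)
        return t_eval, sol


model = torch.jit.script(ScriptableNeuralODE(Decay()))
t_eval, sol = model(torch.ones(3, 2), torch.linspace(0.0, 1.0, 11))
print(sol.shape)  # torch.Size([11, 3, 2])
```

With ten Euler steps of size 0.1 on dx/dt = -x, the final state should be close to 0.9 ** 10 times the initial state.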
Describe the bug
TorchScript is unable to script the NeuralODE class due to a function being redefined. This is a problem because the code contains control flow that tracing would not necessarily respect, so tracing, the alternative, could produce incorrect output.
Steps to Reproduce
Minimal working example:
produces the following errors:
The first time you run it, the error is:
On subsequent attempts, the error is:
Note that changing the solvers doesn't appear to change anything.
Expected behavior
TorchScript should be able to script NeuralODE without errors.
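To illustrate the point in the bug description that tracing is not a safe substitute for scripting here, a minimal sketch with a toy function (step is hypothetical, not torchdyn code): torch.jit.trace records only the branch taken for the example input, while torch.jit.script keeps the control flow.

```python
import torch
from torch import Tensor


def step(x: Tensor) -> Tensor:
    # Data-dependent branch: the path taken depends on the values in x.
    if bool(x.sum() > 0):
        return x * 2
    return -x


# Tracing runs step once on the example input and records only the ops it saw
# (here the x * 2 path); PyTorch emits a TracerWarning about the bool() call.
traced = torch.jit.trace(step, (torch.ones(3),))

# Scripting compiles the source, so both branches survive.
scripted = torch.jit.script(step)

neg = -torch.ones(3)
print(traced(neg))    # still doubles: tensor([-2., -2., -2.])
print(scripted(neg))  # negates: tensor([1., 1., 1.])
```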
I'm using the latest version of torchdyn available on pip and torch==1.11.0+cu102. I'd be very grateful for your advice on how to script this safely with TorchScript.