DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Torchscripting not possible with `NeuralODE` due to function redefinition #163

Open StephenHogg opened 2 years ago

StephenHogg commented 2 years ago

Describe the bug

TorchScript is unable to script the NeuralODE class because a method gets redefined. Tracing is not a safe alternative: the code contains control flow that tracing would not necessarily respect, so a traced model could produce incorrect output.
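To illustrate why tracing is not a drop-in substitute, here is a small sketch (the function name is hypothetical, not from torchdyn). `torch.jit.trace` records only the branch taken for the example input and emits a TracerWarning, while `torch.jit.script` keeps the data-dependent `if`.

```python
import torch
from torch import Tensor


def double_if_positive(x: Tensor) -> Tensor:
    # Data-dependent branch: which path runs depends on the input values.
    if bool(x.sum() > 0):
        return x * 2
    return -x


# Tracing with a positive input freezes the "x * 2" path into the graph.
traced = torch.jit.trace(double_if_positive, torch.ones(2))
# Scripting compiles the Python control flow itself.
scripted = torch.jit.script(double_if_positive)

neg = -torch.ones(2)
print(traced(neg))    # tensor([-2., -2.]), the recorded branch is reused
print(scripted(neg))  # tensor([1., 1.]), the branch is re-evaluated
```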

Steps to Reproduce

Minimal working example:

import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

f = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2),
)

model = NeuralODE(f, solver='tsit5', solver_adjoint='dopri5')
out = torch.jit.script(model)  # raises the errors below

This produces the following errors.

The first time you run it, the error is:

forward(__torch__.torch.nn.modules.container.Sequential self, Tensor input) -> (Tensor):
Expected at most 2 arguments but found 3 positional arguments.
:
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/defunc.py", line 32
    def forward(self, t:Tensor, x:Tensor) -> Tensor:
        self.nfe += 1
        if self.has_time_arg: return self.vf(t, x)
                                     ~~~~~~~ <--- HERE
        else: return self.vf(x)

On subsequent attempts, the error is:

RuntimeError: Can't redefine method: forward on class: __torch__.torchdyn.core.defunc.DEFuncBase (of Python compilation unit at: 0x560bca2c9a70)

Changing the solvers does not appear to make any difference.

Expected behavior

Scripting the model with TorchScript should succeed.

I'm using the latest version of torchdyn available from pip and torch==1.11.0+cu102. I would be very grateful for your advice on how to script this safely with TorchScript.

StephenHogg commented 2 years ago

I've concluded that this is likely because forward() gets wrapped to include a time argument. To avoid this, I gave NeuralODE a network whose forward already takes the time argument, and hit this error instead:

RuntimeError: 
Unknown type name 'Iterable':
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/neuralde.py", line 92
    def forward(self, x:Tensor, t_span:Tensor=None, save_at:Iterable=(), args={}):
                                                            ~~~~~~~~ <--- HERE
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol =  super().forward(x, t_span, save_at, args)

The odd thing is that Iterable is definitely imported at the top of the file.
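A minimal sketch of the workaround of handing over a network that already takes a time argument, assuming the first error comes from TorchScript type-checking both branches of `if self.has_time_arg` (a non-constant `if` is compiled in full, so `self.vf(t, x)` is rejected for an `nn.Sequential` even if that branch never runs). `AddTimeArg` and `VectorField` are hypothetical names, not torchdyn's API.

```python
import torch
from torch import Tensor, nn


class AddTimeArg(nn.Module):
    """Hypothetical adapter: exposes a (t, x) signature for a time-independent net."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # t is accepted so every call site has the same form, but it is unused.
        return self.net(x)


class VectorField(nn.Module):
    """Hypothetical stand-in for a DEFunc-style wrapper with one call form.

    There is no `if self.has_time_arg` branch, so TorchScript never has to
    type-check a one-argument call against a two-argument forward.
    """

    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return self.vf(t, x)


torch.manual_seed(0)
f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
field = torch.jit.script(VectorField(AddTimeArg(f)))
x = torch.randn(8, 2)
dx = field(torch.tensor(0.0), x)
print(dx.shape)  # torch.Size([8, 2])
```

The adapter moves the "does this network use time?" decision out of the scripted code and into model construction, which is ordinary Python and never compiled.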

StephenHogg commented 2 years ago

I looked a bit further and can see that TorchScript does not accept Iterable as a type annotation. Here's the list of supported type hints; would it be possible to change the annotation to one of these? Happy to write a PR if so.

https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
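As one possible replacement from that list, `Optional[Tensor]` with a `None` default is a supported TorchScript type. The sketch below uses a hypothetical module rather than torchdyn's `NeuralODE`; its body is a placeholder that only shows the `None` check TorchScript needs before the tensor can be used.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class SaveAtSketch(nn.Module):
    """Hypothetical module with a scriptable alternative to `save_at: Iterable = ()`."""

    def forward(self, x: Tensor, save_at: Optional[Tensor] = None) -> Tensor:
        # TorchScript refines Optional[Tensor] to Tensor inside this branch.
        if save_at is not None:
            # Placeholder computation: scale by the number of requested points.
            return x * save_at.numel()
        return x


scripted = torch.jit.script(SaveAtSketch())
print(scripted(torch.ones(2)))                                 # tensor([1., 1.])
print(scripted(torch.ones(2), torch.tensor([0.0, 0.5, 1.0])))  # tensor([3., 3.])
```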

StephenHogg commented 2 years ago

Another update: removing that type annotation revealed a second incorrect annotation, as well as a problem with how forward is called. It seems the package would need a refactor to make it safe to script with TorchScript. Let me know if you're interested in doing it; I'd be happy to work with you on it, since it would be great to be able to serialise these models.

Zymrael commented 2 years ago

Thanks for looking into this! I've started a refactor in this branch.

I fixed some of the typing inconsistencies and managed to push through to:

RuntimeError: 
'Tensor' object has no attribute or method 'forward'.:
  File "/home/stefano/michael/diffeqml/torchdyn/torchdyn/core/neuralde.py", line 94
    def forward(self, x:Tensor, t_span:Tensor=None, save_at:Tensor=()):
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol =  super().forward(x, t_span, save_at)
                       ~~~~~~~~~~~~~ <--- HERE
        if self.return_t_eval: return t_eval, sol
        else: return sol

Is that the error you observed?

StephenHogg commented 2 years ago

Hi @Zymrael, yes. That one is a problem because the usual fix there would be:

super(ODEProblem, self).forward(x, t_span, save_at)

but TorchScript does not accept types being passed as arguments, so that fix fails to compile as well. This suggests that some of the inheritance the current structure of the code relies on would need to be reworked. That isn't necessarily a bad thing if it also reduces the amount of indirection in the code. Hope this helps; happy to keep discussing this one.
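A minimal sketch of what replacing inheritance with composition could look like, assuming the solver logic moves into a submodule that the top-level module calls directly, so no `super().forward()` call needs to be scripted. All class names are hypothetical, and the fixed-step Euler loop is only a stand-in for torchdyn's real solvers. The sketch always returns a `(t_eval, sol)` tuple, because TorchScript requires `forward` to have a single return type.

```python
from typing import Tuple

import torch
from torch import Tensor, nn


class Decay(nn.Module):
    """Toy vector field dx/dt = -x with a (t, x) signature."""

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return -x


class EulerProblem(nn.Module):
    """Hypothetical stand-in for an ODEProblem-like solver, held as a submodule."""

    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf

    def forward(self, x: Tensor, t_span: Tensor) -> Tuple[Tensor, Tensor]:
        sol = [x]
        # Fixed-step explicit Euler over the grid given by t_span.
        for i in range(t_span.shape[0] - 1):
            dt = t_span[i + 1] - t_span[i]
            x = x + dt * self.vf(t_span[i], x)
            sol.append(x)
        return t_span, torch.stack(sol)


class NeuralODESketch(nn.Module):
    """Composition instead of inheritance: calls the solver submodule directly."""

    def __init__(self, vf: nn.Module):
        super().__init__()
        self.problem = EulerProblem(vf)

    def forward(self, x: Tensor, t_span: Tensor) -> Tuple[Tensor, Tensor]:
        # One return type for every call; no super() lookup to compile.
        return self.problem(x, t_span)


model = torch.jit.script(NeuralODESketch(Decay()))
t_eval, sol = model(torch.ones(1), torch.tensor([0.0, 0.5, 1.0]))
print(sol.squeeze(1))  # tensor([1.0000, 0.5000, 0.2500])
```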