DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

NeuralODE namespace module naming collision #140

Closed laserkelvin closed 2 years ago

laserkelvin commented 2 years ago

Bug description

The NeuralODE class fails to run when passed a vector field module that happens to have a layer or submodule named u.

On this line, the _prep_integration function iterates over the children of the vector field and, if any of them has an attribute named u, overwrites that attribute, expecting it to be something completely different (a data control, not a user-defined submodule).
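The collision ultimately comes from nn.Module.__setattr__: once a name is registered as a child module, it can no longer be assigned a plain tensor. A minimal sketch of the failure mode, independent of torchdyn (the class name here is illustrative):

```python
import torch
from torch import nn

class VF(nn.Module):
    def __init__(self):
        super().__init__()
        # 'u' is registered as a child module; assigning a plain tensor
        # to the same name later will raise a TypeError
        self.u = nn.Linear(8, 8)

vf = VF()
try:
    # mimics what _prep_integration does to the vector field's children
    vf.u = torch.rand(4, 8).detach()
    collided = False
except TypeError:
    collided = True

print(collided)  # True
```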

Attempting to run the forward method of NeuralODE then raises this rather unhelpful exception, which makes it look like a torch problem:

In [8]: ode(X)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-6911594fe32a> in <module>
----> 1 ode(X)

/scratch2/kinlongk/.conda/specgraph/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/scratch2/kinlongk/.conda/specgraph/lib/python3.9/site-packages/torchdyn/core/neuralde.py in forward(self, x, t_span)
     90 
     91     def forward(self, x:Tensor, t_span:Tensor=None):
---> 92         x, t_span = self._prep_integration(x, t_span)
     93         t_eval, sol =  super().forward(x, t_span)
     94         if self.return_t_eval: return t_eval, sol

/scratch2/kinlongk/.conda/specgraph/lib/python3.9/site-packages/torchdyn/core/neuralde.py in _prep_integration(self, x, t_span)
     86             if hasattr(module, 'u'):
     87                 self.controlled = True
---> 88                 module.u = x[:, excess_dims:].detach()
     89         return x, t_span
     90 

/scratch2/kinlongk/.conda/specgraph/lib/python3.9/site-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
   1218             elif modules is not None and name in modules:
   1219                 if value is not None:
-> 1220                     raise TypeError("cannot assign '{}' as child module '{}' "
   1221                                     "(torch.nn.Module or None expected)"
   1222                                     .format(torch.typename(value), name))

TypeError: cannot assign 'torch.FloatTensor' as child module 'u' (torch.nn.Module or None expected)

Steps to Reproduce

Steps to reproduce the behavior:

  1. Define a simple MLP for our vector field with a submodule named u:
import torch
from torch import nn

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.u = nn.Linear(8, 8)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.u(X)
  2. Wrap it with NeuralODE and call forward with any input:
import torch
from torchdyn.core import NeuralODE

vf = SimpleMLP()
X = torch.rand(128, 8)

ode = NeuralODE(vf)
# the line below will break
t_eval, yhat = ode(X)

Expected behavior

There shouldn't be a collision with the NeuralODE namespace. I suggest renaming the attribute that _prep_integration looks for to something "private", i.e. hasattr(module, "_u"), to minimize the chance of a collision, or renaming it to something more descriptive.
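Beyond renaming, the guard could also verify that the attribute is not itself a registered submodule before overwriting it. A hedged sketch of such a check (the function name and the "_control" attribute name here are illustrative, not torchdyn's actual API):

```python
import torch
from torch import nn

def set_control(module: nn.Module, value: torch.Tensor) -> bool:
    """Only treat the module as controlled if '_control' exists
    and is not a registered child nn.Module."""
    if hasattr(module, "_control") and not isinstance(
        getattr(module, "_control"), nn.Module
    ):
        module._control = value.detach()
        return True
    return False

class PlainVF(nn.Module):
    def __init__(self):
        super().__init__()
        self.u = nn.Linear(8, 8)  # a user layer named 'u' no longer collides

class ControlledVF(nn.Module):
    def __init__(self):
        super().__init__()
        self._control = None  # placeholder, filled in before solving
        self.net = nn.Linear(16, 8)

plain_hit = set_control(PlainVF(), torch.rand(4, 8))       # False
controlled_hit = set_control(ControlledVF(), torch.rand(4, 8))  # True
print(plain_hit, controlled_hit)
```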

If you rename self.u to self._u in SimpleMLP above, the code works.

Additional context

This bug exists as of torchdyn==1.0, and remains on the main branch (per my permalink above).

Zymrael commented 2 years ago

Thanks @laserkelvin! We renamed u to ._control as a quick fix. We are still planning to extend DataControl to be more flexible in the future (passing for example "trajectory" controls for Neural Controlled Differential Equations).
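For context, a minimal sketch of what a data-controlled vector field could look like after the rename: the input is stashed in a plain-tensor attribute named _control and concatenated onto the state at each call. All names here are illustrative and assumed, not torchdyn's actual DataControl implementation.

```python
import torch
from torch import nn

class DataControlledVF(nn.Module):
    """Sketch: '_control' holds a plain tensor (not a submodule),
    so assigning to it never trips nn.Module.__setattr__."""
    def __init__(self, dim: int):
        super().__init__()
        self._control = None               # set by the wrapper before solving
        self.net = nn.Linear(2 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # concatenate the stashed control onto the current state
        return self.net(torch.cat([x, self._control], dim=-1))

vf = DataControlledVF(8)
x0 = torch.rand(128, 8)
vf._control = x0.detach()  # roughly what _prep_integration would do
out = vf(x0)
print(out.shape)  # torch.Size([128, 8])
```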