rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

forward() takes 2 positional arguments but 3 were given #220

Closed: tomzhu0225 closed this issue 1 year ago

tomzhu0225 commented 1 year ago

I have my NODE set up like this:

import torch.nn as nn

# p is defined elsewhere in the original script (not shown in the issue)
class NODE(nn.Module):
    def __init__(self, num_params):
        super(NODE, self).__init__()
        self.num_params=num_params
        self.fc1 = nn.Linear(p + num_params, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, p + num_params)
    def forward(self, x):
        #x = torch.cat([x, params], dim=1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        dx = self.fc3(x)
        dx[:self.num_params]=0
        return dx

When I call odeint during training:

ode_solution=odeint(node,x0,t1)

I get this error:


  File "D:\projects\NODE_for_accelerator\NODE_LINAC\structure.py", line 178, in <module>
    train_loss = train(encoder, decoder, node, train_loader, optimizer, criterion)

  File "D:\projects\NODE_for_accelerator\NODE_LINAC\structure.py", line 143, in train
    ode_solution=odeint(node,x0,t1)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\adjoint.py", line 198, in odeint_adjoint
    ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\adjoint.py", line 25, in forward
    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\odeint.py", line 77, in odeint
    solution = solver.integrate(t)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\solvers.py", line 28, in integrate
    self._before_integrate(t)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\rk_common.py", line 161, in _before_integrate
    f0 = self.func(t[0], self.y0)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\misc.py", line 189, in forward
    return self.base_func(t, y)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torchdiffeq\_impl\misc.py", line 189, in forward
    return self.base_func(t, y)

  File "C:\Users\tomkeen\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)

TypeError: forward() takes 2 positional arguments but 3 were given
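The traceback shows where the extra argument comes from: the solver calls the dynamics as self.func(t[0], self.y0), that is, with a time and a state. The pure-Python sketch below (no PyTorch needed; the class and function names are made up for illustration) reproduces the same TypeError by calling a one-argument forward with two arguments. Because self counts as the first positional argument, Python reports "takes 2 ... but 3 were given".

```python
class OneArgDynamics:
    """Hypothetical stand-in for the NODE above: forward takes only the state."""

    def forward(self, x):
        return x


def call_like_solver(func, t, y):
    # Mimics how the solver invokes the dynamics function: func(t, y)
    return func(t, y)


message = ""
try:
    call_like_solver(OneArgDynamics().forward, 0.0, [1.0, 2.0])
except TypeError as err:
    # Bound method: self + x = 2 allowed, self + t + y = 3 given
    message = str(err)
print(message)
```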

If I drop the time argument and pass only two inputs to odeint, I get a different error:


  File "D:\projects\NODE_for_accelerator\NODE_LINAC\structure.py", line 178, in <module>
    train_loss = train(encoder, decoder, node, train_loader, optimizer, criterion)

  File "D:\projects\NODE_for_accelerator\NODE_LINAC\structure.py", line 143, in train
    ode_solution=odeint(node,x0)

TypeError: odeint_adjoint() missing 1 required positional argument: 't'

I am new to PyTorch and don't understand what is happening. It would be great if you could tell me what's wrong. Thanks!

tomzhu0225 commented 1 year ago

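I understand now: the solver calls the dynamics function as func(t, y) (see self.func(t[0], self.y0) in the traceback), so forward has to accept the time t as well as the state: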

    def forward(self, t, x):  # t is unused here, but the solver always passes it as the first argument
        #x = torch.cat([x, params], dim=1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        dx = self.fc3(x)
        dx[:self.num_params]=0
        return dx
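To see the calling convention end to end, here is a minimal sketch of a toy fixed-step Euler solver in pure Python. It is not torchdiffeq's implementation; it only assumes the same contract the traceback shows: the solver calls func(t, y), and t is the list of times at which the solution is reported, which is also why odeint cannot be called without t. The dynamics class is hypothetical and mirrors the idea of zeroing the derivative of the first num_params entries so they stay constant. (For a batched 2-D PyTorch state, the analogous slice would be dx[..., :num_params]; dx[:num_params] slices the first, batch, dimension.)

```python
def euler_odeint(func, y0, t):
    """Toy fixed-step Euler solver (illustration only, not torchdiffeq).

    Calls func(t, y) and returns the state at every time in t,
    with t[0] as the initial time.
    """
    y = list(y0)
    ys = [list(y0)]
    for t0, t1 in zip(t[:-1], t[1:]):
        dy = func(t0, y)          # same (t, y) order the real solver uses
        h = t1 - t0
        y = [yi + h * dyi for yi, dyi in zip(y, dy)]
        ys.append(y)
    return ys


class FrozenParamDynamics:
    """Hypothetical dynamics: the first num_params entries are parameters
    with zero derivative; the remaining entries follow dy/dt = -y."""

    def __init__(self, num_params):
        self.num_params = num_params

    def forward(self, t, x):  # t is unused, but must be accepted
        dx = [-xi for xi in x]
        for i in range(self.num_params):
            dx[i] = 0.0           # keep parameters constant
        return dx


dyn = FrozenParamDynamics(num_params=1)
sol = euler_odeint(dyn.forward, [3.0, 1.0], [0.0, 0.5, 1.0])
print(sol)  # [[3.0, 1.0], [3.0, 0.5], [3.0, 0.25]]
```

The returned list has one state per entry of t, and the frozen first entry stays at 3.0 throughout.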

achvsujsbxkh commented 1 year ago

Can you tell me how you finally solved the problem (TypeError: forward() takes 2 positional arguments but 3 were given)? Thank you very much.

tomzhu0225 commented 1 year ago

Hi, I resolved the problem by adding a t parameter to forward, changing

-  def forward(self, x)

to

+  def forward(self, t, x)
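If you want to catch this mistake before starting a long training run, one option is to check the signature up front. The helper below is hypothetical (it is not part of torchdiffeq) and uses only the standard library: inspect.signature(...).bind raises TypeError when the arguments do not fit. Pass the bound forward method rather than the module itself, since nn.Module's own __call__ is variadic (as the _call_impl(*input, **kwargs) frame in the traceback suggests) and would accept anything.

```python
import inspect


def accepts_t_and_y(func):
    """Hypothetical helper: True if func can be called as func(t, y)."""
    try:
        # bind() checks the arguments against the signature without calling func
        inspect.signature(func).bind(0.0, [0.0])
    except TypeError:
        return False
    return True


class OldDynamics:
    def forward(self, x):        # original signature: state only
        return x


class NewDynamics:
    def forward(self, t, x):     # fixed signature: time and state
        return x


print(accepts_t_and_y(OldDynamics().forward))  # False
print(accepts_t_and_y(NewDynamics().forward))  # True
```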