tomzhu0225 closed this issue 1 year ago.
I understand now.
```python
def forward(self, t, x):  # t is unused for now, but odeint passes it as the first argument
    # x = torch.cat([x, params], dim=1)
    x = nn.functional.relu(self.fc1(x))
    x = nn.functional.relu(self.fc2(x))
    dx = self.fc3(x)
    dx[:self.num_params] = 0  # zero the first num_params entries of dx (along dim 0)
    return dx
```
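Setting the first `num_params` entries of `dx` to zero means the solver never changes those entries, so values stored there act as constants carried along with the state. That reading is an assumption based on the commented-out `torch.cat([x, params], dim=1)` line, and it holds for a 1-D state: with a batched `(batch, features)` state, `dx[:self.num_params]` zeros the first rows, while `dx[..., :self.num_params]` would zero the first feature columns. Below is a dependency-free sketch of the idea; `euler_step`, `ParamInStateDynamics` and the decay rate `k` are hypothetical names for illustration only.

```python
def euler_step(func, t, y, h):
    """One explicit Euler step; func is called as func(t, y), the odeint convention."""
    dy = func(t, y)
    return [yi + h * dyi for yi, dyi in zip(y, dy)]


class ParamInStateDynamics:
    """Hypothetical dynamics: the first num_params state entries hold constants."""

    def __init__(self, num_params):
        self.num_params = num_params

    def __call__(self, t, y):
        params, x = y[:self.num_params], y[self.num_params:]
        k = params[0]  # hypothetical decay rate stored inside the state
        # Zero derivative for the parameter entries, so they never change.
        return [0.0] * self.num_params + [-k * xi for xi in x]


f = ParamInStateDynamics(num_params=1)
y = euler_step(f, 0.0, [2.0, 1.0], 0.25)
print(y)  # [2.0, 0.5]: k stays 2.0, x decays
```

The parameter entry stays at `2.0` after every step while the remaining entries evolve, which is the same effect `dx[:self.num_params] = 0` has on a 1-D state.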
Can you tell me how you finally solved the problem (TypeError: forward() takes 2 positional arguments but 3 were given)? Thank you very much.
Hi, I resolved the problem by adding a `t` argument, changing `def forward(self, x)` to `def forward(self, t, x)`.
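Why this works, assuming `odeint` here is `torchdiffeq.odeint`: the solver calls the dynamics as `func(t, y)`, and `nn.Module.__call__` forwards those arguments to `forward`, so a `forward(self, x)` method receives three positional arguments (`self`, `t`, `y`). The sketch below reproduces the error without PyTorch; `Module` is a minimal stand-in for `nn.Module`, and `euler_odeint` is a hypothetical fixed-step Euler solver that only mimics the `func(t, y)` calling convention.

```python
class Module:
    """Minimal stand-in for nn.Module: calling the object forwards all args to forward()."""

    def __call__(self, *args):
        return self.forward(*args)


class OneArgDynamics(Module):
    def forward(self, x):  # original signature: no t
        return [-xi for xi in x]


class TwoArgDynamics(Module):
    def forward(self, t, x):  # fixed signature: t accepted even if unused
        return [-xi for xi in x]


def euler_odeint(func, y0, ts):
    """Hypothetical fixed-step Euler solver that calls func(t, y) like odeint."""
    y = list(y0)
    ys = [list(y)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dy = func(t0, y)
        y = [yi + (t1 - t0) * dyi for yi, dyi in zip(y, dy)]
        ys.append(y)
    return ys


try:
    euler_odeint(OneArgDynamics(), [1.0], [0.0, 1.0])
except TypeError as err:
    print(err)  # message ends with "takes 2 positional arguments but 3 were given"

print(euler_odeint(TwoArgDynamics(), [1.0], [0.0, 0.5, 1.0]))  # [[1.0], [0.5], [0.25]]
```

Keeping `t` in the signature also leaves room for time-dependent dynamics later, even if the current model ignores it.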
I have my NODE set up like this:
When I try to call odeint during training:
The following error pops up:
If I try to reduce the number of inputs to odeint:
I am new to PyTorch and I really don't know what is happening. It would be great if you could tell me what's wrong. Thanks!