DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Passing in multiple arguments #137

Closed. varunagrawal closed this issue 2 weeks ago.

varunagrawal commented 2 years ago

Additional Description

I have a network f(x, x_dot, theta) that I wish to train, where x and x_dot are the inputs and theta are the network weights. This is a slightly odd problem since x_dot is the corrupted derivative of x, and I wish to train a network to give me the correct x_dot. To solve the ODE, I need to pass in x at t=0, but the network itself doesn't use x in its forward pass, only x_dot.

How would I pass multiple arguments like this to a NeuralODE in torchdyn? I am guessing the way to do this is to concatenate the two, so I get x_x_dot = torch.cat((x, x_dot)), but I am not sure if this is correct.

In torchdiffeq, what I did was call the solver like so:

import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Network(nn.Module):
    def forward(self, t, state):      # self.mlp is assumed to be defined in __init__
        x, x_dot = state              # unpack the tuple state
        return self.mlp(x_dot)        # only the corrupted derivative is used

network = Network()
y = odeint_adjoint(network, (x_i, x_dot), t_span)

What would be the equivalent approach in torchdyn?

joglekara commented 2 years ago

Thanks for the Q!

The concat approach should work just fine, but I agree that it's not necessarily the most transparent.

Using a pytree-based approach is likely the most flexible option in the long run, though, so that's also something we're keeping an eye on.
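
For reference, a minimal sketch of what the concat approach could look like, assuming the torchdyn 1.x NeuralODE interface (model(x0, t_span) returning t_eval and a trajectory). CorrectedField, the MLP sizes, and the choice to hold x_dot constant with a zero derivative are illustrative assumptions, not part of torchdyn; the only point is that x and x_dot travel through the solver as a single stacked tensor.

import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

class CorrectedField(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))

    def forward(self, z):
        x, x_dot = z.chunk(2, dim=-1)            # split the concatenated state
        dx = self.mlp(x_dot)                     # the network only consumes x_dot
        dx_dot = torch.zeros_like(x_dot)         # hold x_dot fixed during the solve
        return torch.cat([dx, dx_dot], dim=-1)   # derivative of the full stacked state

dim = 2
x, x_dot = torch.randn(16, dim), torch.randn(16, dim)
model = NeuralODE(CorrectedField(dim), solver='dopri5', sensitivity='adjoint')
t_eval, traj = model(torch.cat([x, x_dot], dim=-1), torch.linspace(0., 1., 10))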

varunagrawal commented 2 years ago

A pytree seems like a large hammer for a small nail. If there is no direct support for passing in tuples as arguments, I imagine that would be easier to add in the short term: just a couple of checks (isinstance(x, tuple) and isinstance(x[i], torch.Tensor)) and then continue from there.
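
A hypothetical sketch of what those checks could look like inside a fixed-step solver (illustrative only, not torchdyn's implementation): with the isinstance checks in place, an Euler step can apply the same update to every tensor in the tuple.

import torch

def euler_step(f, t, x, dt):
    # hypothetical tuple-aware Euler step; f(t, x) is expected to return a
    # structure matching its input
    if isinstance(x, tuple):
        assert all(isinstance(xi, torch.Tensor) for xi in x), "tuple entries must be tensors"
        dx = f(t, x)
        return tuple(xi + dt * dxi for xi, dxi in zip(x, dx))
    return x + dt * f(t, x)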

joglekara commented 2 years ago

Update here: fixed_odeint supports the state as a dict for now. We have yet to extend it to the adaptive solvers.

zjowowen commented 6 months ago

Hi, I have a similar need when building my repo on top of torchdyn.

I need odeint to support passing a data type like (dx, dlog(x)) for building generative models such as continuous normalizing flows. The variable x should be a tensor of any shape, while dlog(x) is simply a scalar. (I tried reshaping and concatenating these tensors into one tensor and reversing the operation when calling modules, but this interferes with backpropagation through grad_fn.)

For now, it seems that I have to fall back to torchdiffeq, which accepts tuple inputs.
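
For reference, a minimal sketch of the tuple-state pattern in torchdiffeq for a CNF-style model. CNFField and the placeholder divergence term are illustrative assumptions (a real CNF would compute an estimate of the trace of the Jacobian of dx with respect to x); the point is only that the state is the pair (x, logp) and the vector field returns a tuple with the same structure, so no reshaping or concatenation is needed.

import torch
import torch.nn as nn
from torchdiffeq import odeint

class CNFField(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))

    def forward(self, t, state):
        x, logp = state                           # x: (batch, dim), logp: (batch, 1)
        dx = self.net(x)
        dlogp = -dx.sum(dim=-1, keepdim=True)     # placeholder for the divergence term
        return dx, dlogp                          # tuple out matches the tuple state in

x0, logp0 = torch.randn(8, 2), torch.zeros(8, 1)
xs, logps = odeint(CNFField(2), (x0, logp0), torch.linspace(0., 1., 10))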

I suggest that torchdyn support tree-like tensor inputs. One possible implementation is https://github.com/opendilab/treevalue.