varunagrawal opened this issue 2 years ago
Thanks for the Q!
The concat approach should work just fine, but I agree that it's not necessarily the most transparent. A pytree-based approach is likely the most flexible in the long run, though, so that's also something we're keeping an eye on.
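Here is a minimal sketch of the concat approach. It assumes a batch of states `x` of width 3 plus one extra scalar per sample, packed into a single tensor. `ConcatField` and `euler` are hypothetical names, and the Euler loop stands in for a torchdyn solver rather than using its API. The vector field undoes the packing with `torch.split` and re-packs the derivative with `torch.cat`. Both ops are differentiable, so gradients reach the initial condition and the weights.

```python
import torch
import torch.nn as nn


class ConcatField(nn.Module):
    """Hypothetical vector field on a state made by concatenating x and a scalar."""

    def __init__(self, x_dim: int):
        super().__init__()
        self.x_dim = x_dim
        self.net = nn.Linear(x_dim, x_dim)

    def forward(self, t, state):
        # Undo the concatenation: first x_dim columns are x, the last column is the extra scalar.
        x, extra = torch.split(state, [self.x_dim, 1], dim=-1)
        dx = self.net(x)
        dextra = -(x ** 2).sum(dim=-1, keepdim=True)  # illustrative scalar dynamics
        # Re-pack the derivatives in the same layout as the state.
        return torch.cat([dx, dextra], dim=-1)


def euler(field, state, t0=0.0, t1=1.0, steps=10):
    """Hypothetical fixed-step explicit Euler loop, standing in for a real solver."""
    h = (t1 - t0) / steps
    for k in range(steps):
        state = state + h * field(t0 + k * h, state)
    return state


torch.manual_seed(0)
x0 = torch.randn(4, 3, requires_grad=True)
field = ConcatField(x_dim=3)
final = euler(field, torch.cat([x0, torch.zeros(4, 1)], dim=-1))
x1, extra1 = torch.split(final, [3, 1], dim=-1)

# Gradients reach both the initial condition and the weights through split/cat.
extra1.sum().backward()
```

The main cost is bookkeeping. The split sizes must match the packing order everywhere the state is touched, and this is where the lack of transparency comes from.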
A pytree seems like a large hammer for a small nail. If there is no direct support for passing in tuples as arguments, I imagine that would be easier to add in the short term: just a couple of checks (`isinstance(x, tuple)` and `isinstance(x[i], torch.Tensor)`) and then continue from there.
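Those checks could look like the sketch below, assuming a fixed-step Euler update. `is_tensor_tuple` and `euler_step` are hypothetical helpers, not torchdyn functions. A tuple state is updated element by element, a plain tensor falls through to the usual update, and anything else is rejected.

```python
import torch


def is_tensor_tuple(state):
    # The check from the comment: a tuple whose entries are all tensors.
    return isinstance(state, tuple) and all(isinstance(s, torch.Tensor) for s in state)


def euler_step(field, t, state, h):
    """Hypothetical Euler step accepting a single tensor or a tuple of tensors."""
    deriv = field(t, state)
    if is_tensor_tuple(state):
        # Update each component with its own derivative.
        return tuple(s + h * d for s, d in zip(state, deriv))
    if isinstance(state, torch.Tensor):
        return state + h * deriv
    raise TypeError(f"unsupported state type: {type(state).__name__}")


def toy_field(t, state):
    # Tuple-valued dynamics: dx/dt = -x, dlogp/dt = -sum(x) per sample.
    x, logp = state
    return (-x, -x.sum(dim=-1))


x, logp = torch.ones(2, 3), torch.zeros(2)
new_x, new_logp = euler_step(toy_field, 0.0, (x, logp), 0.1)
```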
Update here: for now, `fixed_odeint` supports the state as a `dict`. We have yet to extend it to the adaptive solver.
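For illustration only, a fixed-step Euler solver over a dict state might look like the sketch below. `fixed_euler_dict` is a hypothetical name, and this is not the actual `fixed_odeint` implementation. It assumes the vector field returns a dict with the same keys as the state.

```python
import torch


def fixed_euler_dict(field, state, t0, t1, steps):
    """Hypothetical fixed-step Euler solver over dict states (not torchdyn's fixed_odeint)."""
    h = (t1 - t0) / steps
    for k in range(steps):
        deriv = field(t0 + k * h, state)
        # Update every entry with its own derivative; keys must match.
        state = {key: value + h * deriv[key] for key, value in state.items()}
    return state


def decay(t, state):
    # dx/dt = -x, and "aux" accumulates the integral of sum(x) over time.
    return {"x": -state["x"], "aux": state["x"].sum(dim=-1)}


out = fixed_euler_dict(decay, {"x": torch.ones(1, 2), "aux": torch.zeros(1)}, 0.0, 1.0, 100)
```

With 100 steps, `out["x"]` is about 0.99^100 ≈ 0.366, close to the exact value exp(-1) ≈ 0.368.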
Hi, I have similar needs in a repo I'm building on top of torchdyn.
I need `odeint` to accept a state such as `(dx, dlog(x))` so I can build generative models like continuous normalizing flows. `x` can be a tensor of any shape, while `dlog` is simply a scalar. (I tried reshaping and concatenating these tensors into one tensor and undoing that inside the modules, but it caused problems for the backward pass through `grad_fn`.)
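For the continuous normalizing flow case, the `(z, logp)` state follows the instantaneous change of variables: dz/dt = f(z) and d log p(z(t))/dt = -tr(df/dz). The sketch below uses a tuple state and assumes a small dimension, so the trace can be computed exactly with one `torch.autograd.grad` call per coordinate. `CNFField` is a hypothetical name, and the Euler loop stands in for a real solver.

```python
import torch
import torch.nn as nn


class CNFField(nn.Module):
    """Hypothetical CNF vector field on a tuple state (z, logp)."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 16), nn.Tanh(), nn.Linear(16, dim))

    def forward(self, t, state):
        z, logp = state
        with torch.enable_grad():
            if not z.requires_grad:
                z = z.detach().requires_grad_(True)
            dz = self.net(z)
            trace = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
            for i in range(z.shape[1]):
                # Samples do not interact, so the gradient of the batch sum of dz[:, i]
                # gives d dz_i / d z_i for each sample in column i.
                grad_i = torch.autograd.grad(dz[:, i].sum(), z, create_graph=True)[0]
                trace = trace + grad_i[:, i]
        # dz/dt = f(z), d logp/dt = -trace(df/dz)
        return dz, -trace


torch.manual_seed(0)
cnf = CNFField(dim=2)
z, logp = torch.randn(8, 2), torch.zeros(8)
h = 0.1
for k in range(10):
    dz, dlogp = cnf(k * h, (z, logp))
    z, logp = z + h * dz, logp + h * dlogp
```

Exact traces cost one backward pass per dimension. For larger `x`, CNF implementations typically switch to a stochastic (Hutchinson) trace estimator.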
For now, it seems I have to go back to torchdiffeq, which accepts tuple inputs.
I suggest that torchdyn support tree-structured tensor inputs. One existing implementation of such a structure is https://github.com/opendilab/treevalue.
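A tree-structured state mostly needs a map over its leaves. The hypothetical `tree_map` and `tree_map2` helpers below sketch the idea for nested dicts, tuples, and lists of tensors. They are not the treevalue or JAX pytree APIs, and they do not handle namedtuples or custom containers.

```python
import torch


def tree_map(fn, tree):
    """Hypothetical helper: apply fn to every tensor leaf, keeping the nesting."""
    if isinstance(tree, torch.Tensor):
        return fn(tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (tuple, list)):
        return type(tree)(tree_map(fn, v) for v in tree)
    raise TypeError(f"unsupported leaf type: {type(tree).__name__}")


def tree_map2(fn, a, b):
    """Hypothetical helper: combine two trees with the same structure leaf by leaf."""
    if isinstance(a, torch.Tensor):
        return fn(a, b)
    if isinstance(a, dict):
        return {k: tree_map2(fn, a[k], b[k]) for k in a}
    if isinstance(a, (tuple, list)):
        return type(a)(tree_map2(fn, x, y) for x, y in zip(a, b))
    raise TypeError(f"unsupported leaf type: {type(a).__name__}")


# A nested state: a tensor plus a tuple of per-sample scalars.
state = {"x": torch.ones(2, 3), "aux": (torch.zeros(2), torch.zeros(2))}
deriv = tree_map(torch.ones_like, state)
# One Euler step, h = 0.1, applied leaf by leaf.
new = tree_map2(lambda s, d: s + 0.1 * d, state, deriv)
```

A solver written against these two maps never needs to know the state's layout, which is the appeal of the pytree route.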
Additional Description
I have a network I wish to train, `f(x, x_dot, theta)`, where `x` and `x_dot` are the inputs and `theta` are the network weights. This is a slightly odd problem, since `x_dot` is the corrupted derivative of `x` and I wish to train a network to give me the correct `x_dot`. To solve the ODE, I need to pass in `x` at `t=0`, but the network itself doesn't use `x` in its forward pass, only `x_dot`.

How would I pass in multiple arguments like this to a `NeuralODE` in torchdyn? I am guessing the way to do this is to concatenate the two so I get `x_x_dot = torch.cat((x, x_dot))`, but I am not sure if this is correct.

In `torchdiffeq`, what I did was call the solver like so:

What would be the equivalent approach in `torchdyn`?
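Assume `x_dot` is an observed, corrupted derivative that is held fixed over the integration interval. Under that assumption, the concatenation guess can work if the vector field gives the `x_dot` half a zero derivative, so that only `x` evolves as dx/dt = f(x_dot; theta). `ConcatCorrectionField` and `euler` are hypothetical names, and the loop is not a torchdyn solver. Note that `torch.cat((x, x_dot))` without a `dim` argument joins along dimension 0. For `(batch, features)` tensors, that stacks the batch rather than the features, so the sketch passes `dim=-1`.

```python
import torch
import torch.nn as nn


class ConcatCorrectionField(nn.Module):
    """Hypothetical field on the concatenated state [x, x_dot].

    The x half evolves as dx/dt = f(x_dot; theta); the x_dot half gets a zero
    derivative so it stays equal to the observed corrupted value (assumption).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.net = nn.Linear(dim, dim)

    def forward(self, t, state):
        x, x_dot = torch.split(state, self.dim, dim=-1)
        # Only x_dot is fed to the network; x is carried along as state.
        return torch.cat([self.net(x_dot), torch.zeros_like(x_dot)], dim=-1)


def euler(field, state, t0=0.0, t1=1.0, steps=20):
    """Hypothetical fixed-step Euler loop, not a torchdyn API."""
    h = (t1 - t0) / steps
    for k in range(steps):
        state = state + h * field(t0 + k * h, state)
    return state


torch.manual_seed(0)
x, x_dot = torch.randn(4, 3), torch.randn(4, 3)
field = ConcatCorrectionField(3)
# Concatenate along the feature axis (dim=-1), not the default dim=0.
x_x_dot = torch.cat((x, x_dot), dim=-1)
x1, x_dot1 = torch.split(euler(field, x_x_dot), 3, dim=-1)
```

Another option gives the same trajectory for `x`. Keep only `x` as the state and store `x_dot` on the module before solving, so the network reads it from an attribute instead of from the state.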