rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Optimization question for input-conditioned dynamics #207

Closed: qu-gg closed this issue 2 years ago

qu-gg commented 2 years ago

Hi!

I have an optimization question about speeding up or reorganizing odeint calls for better performance.

In my current setup, a hypernetwork takes a portion of the input sequence and outputs the parameters of an MLP that defines the ODE dynamics. As a result, for N samples in a batch I have N unique dynamics functions, which require N separate odeint calls.

I've tried a few approaches to solve this: parallelizing with PyTorch multiprocessing (which runs into computation-graph issues for the gradients); concatenating the initial states of all samples into an N x Dim tensor and making a single odeint call, where the forward function splits the tensor and loops over each dynamics MLP; and simply increasing the integration tolerances so that each odeint call is faster. I'm currently using the non-adjoint method, since memory is not a concern.
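When every per-sample dynamics MLP shares the same architecture, the per-sample loop inside the forward function can often be replaced by batched matrix multiplies. The sketch below assumes, hypothetically, a one-hidden-layer tanh MLP whose weights the hypernetwork emits as stacked tensors `W1`, `b1`, `W2`, `b2`; these names and shapes are illustrative and not taken from the original setup. It checks that `torch.bmm` produces the same derivatives as the loop:

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes: N samples, state dim D, hidden width H.
# Sample i has its own MLP: f_i(z) = W2[i] @ tanh(W1[i] @ z + b1[i]) + b2[i]
N, D, H = 8, 3, 16
W1 = torch.randn(N, H, D)
b1 = torch.randn(N, H)
W2 = torch.randn(N, D, H)
b2 = torch.randn(N, D)

def dynamics_loop(t, z):
    # z: (N, D). Evaluates each sample's MLP one at a time (the looped forward).
    outs = []
    for i in range(N):
        h = torch.tanh(W1[i] @ z[i] + b1[i])
        outs.append(W2[i] @ h + b2[i])
    return torch.stack(outs)

def dynamics_batched(t, z):
    # Same computation with batched matmuls: (N, H, D) x (N, D, 1) -> (N, H, 1)
    h = torch.tanh(torch.bmm(W1, z.unsqueeze(-1)).squeeze(-1) + b1)
    # (N, D, H) x (N, H, 1) -> (N, D, 1), then drop the trailing dim
    return torch.bmm(W2, h.unsqueeze(-1)).squeeze(-1) + b2

z = torch.randn(N, D)
print(torch.allclose(dynamics_loop(0.0, z), dynamics_batched(0.0, z), atol=1e-5))  # True
```

Because `dynamics_batched` takes `(t, z)` and returns a tensor shaped like `z`, it can be passed directly to a single `odeint` call on the `(N, D)` state, with no Python loop per function evaluation.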

Since I haven't had any success with this for quite some time, I figured I'd post here to see whether you have any thoughts or ideas on potential optimization directions.

Here's what the current setup looks like, for clarity: [image not preserved]

Thanks much in advance!

rtqichen commented 2 years ago

I don't really see why you need N separate odeint calls. Couldn't you concatenate all of them into a single state? We also allow passing in a tuple of tensors for the state, so you can do something like:

ode_func_concatenated = lambda t, state: tuple(ode_func[i](state[i]) for i in range(<num_ODEs>))
zts = odeint(ode_func_concatenated, z_init, <other args>)

I'd also look into batching these hypernetworks, so you don't need to compute them sequentially. On the odeint side, though, you can definitely concatenate all the systems and solve them once; this should reduce some of the overhead of calling odeint.

qu-gg commented 2 years ago

Figured this out; thanks much for your help!