qu-gg closed this issue 2 years ago.
I don't really see why you need N separate odeint calls. Couldn't you concatenate all of them into a single state? We also allow passing in a tuple of tensors for the state, so you can do something like:
ode_func_concatenated = lambda t, state: tuple(ode_func[i](state[i]) for i in range(<num_ODEs>))
zts = odeint(ode_func_concatenated, z_init, <other args>)
I'd look into batching these hypernetworks as well, so you don't need to compute them sequentially. But on the odeint side, you can definitely just concatenate all the systems and solve once; this should help reduce some of the overhead of calling odeint.
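The following sketch shows what batching the hypernetwork could look like: one forward pass produces the dynamics-MLP parameters for every sample in the batch. It assumes the dynamics MLP has a single hidden layer and that each sample is summarized by a context vector, such as an encoding of part of the input sequence. `HyperNet`, `ctx_dim`, and `hidden` are hypothetical names, not part of torchdiffeq.

```python
import torch
import torch.nn as nn

class HyperNet(nn.Module):
    """Hypothetical hypernetwork: maps per-sample context vectors to the weights of a
    one-hidden-layer dynamics MLP (dim -> hidden -> dim), for the whole batch at once."""

    def __init__(self, ctx_dim, dim, hidden):
        super().__init__()
        self.dim, self.hidden = dim, hidden
        n_params = hidden * dim + hidden + dim * hidden + dim
        self.net = nn.Sequential(nn.Linear(ctx_dim, 64), nn.ReLU(), nn.Linear(64, n_params))

    def forward(self, ctx):
        # ctx: (N, ctx_dim) -> flat parameters (N, n_params), one row per sample
        p = self.net(ctx)
        d, h, n = self.dim, self.hidden, ctx.shape[0]
        w1, b1, w2, b2 = torch.split(p, [h * d, h, d * h, d], dim=1)
        # Reshape into batched weight matrices: w1 maps D -> H, w2 maps H -> D.
        return w1.reshape(n, h, d), b1, w2.reshape(n, d, h), b2

torch.manual_seed(0)
hyper = HyperNet(ctx_dim=8, dim=4, hidden=16)
ctx = torch.randn(5, 8)  # 5 samples in the batch
w1, b1, w2, b2 = hyper(ctx)
print(w1.shape, b1.shape, w2.shape, b2.shape)
```

Because every sample's parameters come from the same batched forward pass, gradients from the ODE solve flow back into `HyperNet` in one backward pass, without a Python loop over samples.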
Figured this out; thanks much for your help!
Hi!
I have an optimization question about speeding up or reorganizing odeint calls.
In my current setup, a hypernetwork uses a portion of the input sequence to output the parameters of an ODE dynamics MLP. As a result, a batch of N samples has N unique dynamics functions, which currently require N separate odeint calls.
I've tried a few approaches to solve this. First, I parallelized with PyTorch multiprocessing, which ran into computation-graph issues for the gradient. Second, I concatenated the initial states of all samples into a single N x Dim tensor and made one odeint call, where the forward function splits the tensor and loops over each dynamics MLP. Third, I simply loosened (increased) the integration tolerances to make each odeint call faster. I'm currently using the non-adjoint method since memory is not a concern.
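One way to remove the per-sample loop inside the forward function is to stack the N sets of MLP weights and evaluate all N MLPs at once with `torch.bmm`. This is a swapped-in batching technique, offered as a sketch rather than the maintainers' method. It assumes one-hidden-layer MLPs with `tanh` activations, and the random weights are hypothetical placeholders for hypernetwork outputs. The looped version is included only to check that both give the same result.

```python
import torch

def batched_mlp_dynamics(z, w1, b1, w2, b2):
    """Evaluate N different one-hidden-layer MLPs on N states in one shot.
    Shapes: z (N, D), w1 (N, H, D), b1 (N, H), w2 (N, D, H), b2 (N, D)."""
    h = torch.tanh(torch.bmm(w1, z.unsqueeze(-1)).squeeze(-1) + b1)  # (N, H)
    return torch.bmm(w2, h.unsqueeze(-1)).squeeze(-1) + b2           # (N, D)

def looped_mlp_dynamics(z, w1, b1, w2, b2):
    """Reference implementation: the per-sample loop, one MLP at a time."""
    out = []
    for i in range(z.shape[0]):
        h = torch.tanh(w1[i] @ z[i] + b1[i])
        out.append(w2[i] @ h + b2[i])
    return torch.stack(out)

torch.manual_seed(0)
N, D, H = 5, 4, 16
# Hypothetical per-sample weights standing in for hypernetwork outputs.
w1, b1 = torch.randn(N, H, D), torch.randn(N, H)
w2, b2 = torch.randn(N, D, H), torch.randn(N, D)
z = torch.randn(N, D)

fast = batched_mlp_dynamics(z, w1, b1, w2, b2)
slow = looped_mlp_dynamics(z, w1, b1, w2, b2)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```

With this in place, the whole batch could be integrated with a single call such as `odeint(lambda t, z: batched_mlp_dynamics(z, w1, b1, w2, b2), z0, t)` on an `(N, D)` initial state. Gradients then reach the hypernetwork through `w1`, `b1`, `w2`, `b2`. As with any concatenated state, an adaptive solver will choose one step size for the whole batch.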
Since I haven't had any success with this for quite some time, I figured I'd post here to see whether you have any thoughts or ideas for possible optimization directions.
Here's what the current setup looks like, for clarity:
Thanks much in advance!