patrick-kidger / torchcde

Differentiable controlled differential equation solvers for PyTorch with GPU support and memory-efficient adjoint backpropagation.
Apache License 2.0

Integrate the ODE function of the CDE system to infinity #48

Open Basicallyimwell opened 1 year ago

Basicallyimwell commented 1 year ago

Hi Patrick, I am really amazed by, and grateful for, the work you and your team have done on handling these dynamic systems. I'm a medical student, so my thinking may be wrong from a mathematician's perspective; please forgive me if I ask silly questions.

My work, in a nutshell, is to estimate the signal intensity of dynamic contrast-enhanced MRI images at any time point. The latent state in my model represents human body dynamics that are controlled by a Neural CDE system (you may regard them as interpolated signal intensities). Interestingly, and specific to this study, you can imagine that the initial latent state z0 is more or less similar to zT when T is large enough (e.g. 2 days later, all the contrast should already have been eliminated from the body, just as at z0 the contrast has not yet arrived).

From my understanding, a CDE first combines the ODE func (a neural network) and the control path X into a vector field, then passes that vector field as the func to torchdiffeq.odeint(). I am still considering the optimal (clinically justified) method for initializing a reliable latent state. For now, I would only like to integrate the ODE func alone (not the vector field) to infinity, or to a relatively large time, outside my model, to get zT and thus the initial latent state (using the same ODE func that will be used in the CDE).
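The reduction described above can be sketched in a few lines. As I understand torchcde's design, `func(t, z)` returns a matrix of shape (hidden_channels, input_channels), and the ODE vector field is that matrix applied to the derivative of the control path, dz/dt = f_θ(t, z) · dX/dt (torchcde's interpolation objects provide `X.derivative(t)` for this). The sketch below uses plain Python lists in place of tensors; `cde_vector_field` and the toy `func`/`dXdt` are hypothetical illustrations, not torchcde's API.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def cde_vector_field(
    func: Callable[[float, Vector], Matrix],
    dXdt: Callable[[float], Vector],
) -> Callable[[float, Vector], Vector]:
    """Wrap a CDE (func, X) as an ODE vector field: dz/dt = func(t, z) @ dX/dt(t)."""

    def vector_field(t: float, z: Vector) -> Vector:
        F = func(t, z)  # shape (hidden, input): one row per latent channel
        dx = dXdt(t)    # shape (input,): derivative of the control path at time t
        # Matrix-vector product; the result has the same length as z.
        return [sum(f_ij * dx_j for f_ij, dx_j in zip(row, dx)) for row in F]

    return vector_field


# Toy example (hypothetical): 2 latent channels, 2 input channels (time, signal).
def toy_func(t: float, z: Vector) -> Matrix:
    return [[-z[0], 1.0], [0.5, -z[1]]]


def toy_dXdt(t: float) -> Vector:
    return [1.0, 2.0]  # time channel grows at rate 1, signal channel at rate 2


vf = cde_vector_field(toy_func, toy_dXdt)
print(vf(0.0, [1.0, 3.0]))  # [1.0, -5.5]
```

Because `func` returns a matrix rather than a vector, integrating "the ODE func alone" implicitly requires fixing some dX/dt (for example, a constant direction); which choice is clinically meaningful is an open assumption here.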

I would like to ask for your valuable comments on this. It would be great if you could give me some suggestions on how to integrate it (elegantly) rather than just setting a large number. I have also read the paper "Deep Equilibrium Models", but I am not sure whether their fixed-point solvers can be applied to ODE functions. Do you have any papers or topics to suggest so that I can further research the possibility of doing that? I look forward to your reply.
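On the fixed-point question: a steady state z* of an autonomous ODE dz/dt = f(z) is exactly a point where f(z*) = 0, so it can be handed to a root-finder, the same kind of problem DEQ-style solvers address (the Deep Equilibrium Models paper uses Broyden's quasi-Newton method). The sketch below swaps in plain Newton's method with a finite-difference Jacobian, in pure Python for a small state; `newton_steady_state` and `solve_linear` are hypothetical helpers. One caveat: a root-finder returns *a* zero of f, but it does not tell you whether the trajectory starting from your z0 actually converges to that zero (an equilibrium can be unstable, or there may be several), which is one reason to integrate forward instead when the limit of a particular trajectory is what you need.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def solve_linear(A: Matrix, b: Vector) -> Vector:
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(row) + [b_i] for row, b_i in zip(A, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        if abs(M[pivot][col]) < 1e-14:
            raise ValueError("Jacobian is (numerically) singular")
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        s = sum(M[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (M[r][n] - s) / M[r][r]
    return x


def newton_steady_state(
    f: Callable[[Vector], Vector],
    z0: Vector,
    tol: float = 1e-10,
    max_iter: int = 50,
    h: float = 1e-7,
) -> Vector:
    """Find z with f(z) = 0 (an equilibrium of dz/dt = f(z)) by Newton's method."""
    z = list(z0)
    n = len(z)
    for _ in range(max_iter):
        fz = f(z)
        if max(abs(v) for v in fz) < tol:
            return z
        # Forward-difference Jacobian: J[i][j] approximates d f_i / d z_j.
        J = [[0.0] * n for _ in range(n)]
        for j in range(n):
            zp = list(z)
            zp[j] += h
            fp = f(zp)
            for i in range(n):
                J[i][j] = (fp[i] - fz[i]) / h
        # Newton update: solve J * step = -f(z), then move.
        step = solve_linear(J, [-v for v in fz])
        z = [zi + si for zi, si in zip(z, step)]
    raise RuntimeError("Newton's method did not converge")


# Toy autonomous system (hypothetical) with equilibrium at z = (1, 0.5).
def toy_f(z: Vector) -> Vector:
    return [1.0 - z[0], z[0] ** 2 - 2.0 * z[1]]


print(newton_steady_state(toy_f, [0.0, 0.0]))  # approximately [1.0, 0.5]
```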

The last part is just my appreciation for you, sir! I like your "mentally substitute the above treatment throughout" in the "textbook", because I did that a lot, haha. The human body follows a far more complex dynamic system that we often have to regard as a black box. I do believe your work could significantly help in uncovering these unknown body dynamics in the clinical field. Big applause!

patrick-kidger commented 1 year ago

Hey there! Thank you, I'm glad you like my work.

For integrating to a steady state -- indeed, it's possible to do much better than simply set a large number. The idea is simply to set a tolerance ε and stop integrating once |dy/dt| < ε.
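That stopping rule can be sketched independently of any library. Below is a minimal pure-Python version, assuming a vector field `f(t, y)`, a fixed-step classical RK4 integrator, and the infinity norm for |dy/dt|; `integrate_to_steady_state`, `dt`, `eps` and `t_max` are hypothetical names. It only illustrates the criterion and is not Diffrax's implementation.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def integrate_to_steady_state(
    f: Callable[[float, Vector], Vector],
    y0: Vector,
    dt: float = 0.01,
    eps: float = 1e-6,
    t_max: float = 1e4,
) -> Tuple[float, Vector]:
    """Integrate dy/dt = f(t, y) with fixed-step RK4 until max_i |dy_i/dt| < eps.

    Returns the time at which the criterion was met and the state at that time.
    """
    t, y = 0.0, list(y0)
    while t < t_max:
        k1 = f(t, y)
        # Steady-state check: stop once the vector field is (nearly) zero.
        if max(abs(v) for v in k1) < eps:
            return t, y
        k2 = f(t + dt / 2, [yi + dt / 2 * k for yi, k in zip(y, k1)])
        k3 = f(t + dt / 2, [yi + dt / 2 * k for yi, k in zip(y, k2)])
        k4 = f(t + dt, [yi + dt * k for yi, k in zip(y, k3)])
        y = [
            yi + dt / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
        ]
        t += dt
    # A finite cap is kept as a safety net in case no steady state exists.
    raise RuntimeError("no steady state reached before t_max")


# Toy decaying system (hypothetical) whose steady state is y = (1, 0.5).
def toy_vector_field(t: float, y: Vector) -> Vector:
    return [1.0 - y[0], y[0] - 2.0 * y[1]]


t_stop, y_stop = integrate_to_steady_state(toy_vector_field, [0.0, 0.0])
print(t_stop, y_stop)  # t_stop is roughly 14; y_stop is close to [1.0, 0.5]
```

In the issue's setting, `f` would wrap the trained network and the returned state would play the role of zT. A small |dy/dt| only shows that the dynamics have slowed down; it does not by itself prove the state is close to an equilibrium.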

If you'd like to try this, then I recommend using Diffrax; see this example of solving to a steady state. (torchcde was just a small research project and does not support the above feature. Diffrax is now always my recommended choice, and it provides the production-quality integrators we now use.)

Basicallyimwell commented 1 year ago

Hi Patrick, that sounds good to me! I will take a deeper look at that code and see if I can make an equivalent implementation with my current setup. (I expect some trouble installing jaxlib_cuda on Windows at this stage, ha.)

Thank you for answering, and have a nice day!