atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

NeuralODE trajectory API is quite limiting #113

Open rsanchezgarc opened 7 months ago

rsanchezgarc commented 7 months ago

Hi,

I am trying to use your package with SuperResModel (from torchcfm.models.unet.unet import SuperResModel) and other custom models that take kwargs in their forward method, but it seems that NeuralODE.trajectory is not compatible with those models?

Could you please add a model_kwargs parameter to NeuralODE.trajectory, NeuralODE.forward, etc.?

Thanks!

kilianFatras commented 7 months ago

Hello,

This seems to be a problem with Torchdyn. A workaround might be to use torchdiffeq instead, or to write your own custom Euler integration method. Unfortunately, as the NeurIPS deadline is only one month away, I will not have time to look into this issue, especially as it is not really related to TorchCFM but rather to Torchdyn.
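For the "custom Euler integration" route, a minimal sketch could look like the following. It assumes the model is called as model(t, x, **model_kwargs) and returns the velocity dx/dt; the function name euler_trajectory is hypothetical and not part of TorchCFM or Torchdyn. It only uses `+` and scalar `*`, so x0 can be a float or a torch.Tensor, and t_span can be a list of floats or a 1-D tensor such as torch.linspace(0, 1, 100).

```python
from typing import Any, Callable, List


def euler_trajectory(
    model: Callable[..., Any],
    x0: Any,
    t_span: Any,
    **model_kwargs: Any,
) -> List[Any]:
    """Fixed-step Euler integration of dx/dt = model(t, x, **model_kwargs).

    Hypothetical helper (not a TorchCFM/Torchdyn API). Returns the state at
    every time in t_span, with the first entry equal to x0.
    """
    traj = [x0]
    x = x0
    # Step from each time point to the next one.
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        dt = t1 - t0
        # Extra conditioning (e.g. a low-res image or class labels) is
        # forwarded untouched to the model at every step.
        x = x + dt * model(t0, x, **model_kwargs)
        traj.append(x)
    return traj
```

With a UNet-style model you would pass tensors, e.g. euler_trajectory(model, x0, torch.linspace(0, 1, 100), low_res=low_res_img), where low_res is an assumed keyword name; check the actual forward signature of the model you use. Fixed-step Euler is less accurate than the adaptive solvers in torchdiffeq, so more steps may be needed.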

Best, K.

atong01 commented 7 months ago

I would implement this by subclassing SuperResModel, i.e.:

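from torchcfm.models.unet.unet import SuperResModel

class MySuperResModel(SuperResModel):
    # Assumes the extra conditioning is stored beforehand as self.model_kwargs (a dict).
    def forward(self, t, x, *args, **kwargs):
        return super().forward(t, x, **self.model_kwargs)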

I think we do this in the conditional example.
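A composition-based alternative to subclassing is to wrap the model in a small callable that binds the extra kwargs once and exposes the plain (t, x) signature that ODE solvers call. This is a sketch under assumptions: the class name KwargsBoundField is hypothetical, and the wrapped model is assumed to accept (t, x, **model_kwargs). torchdiffeq.odeint accepts any callable func(t, y), so a plain callable like this should work there; torchdyn's NeuralODE may expect an nn.Module, in which case the same logic belongs in an nn.Module subclass as in the comment above.

```python
from typing import Any, Callable, Dict


class KwargsBoundField:
    """Hypothetical wrapper: exposes f(t, x) and forwards stored model kwargs."""

    def __init__(self, model: Callable[..., Any], **model_kwargs: Any) -> None:
        self.model = model
        # Copy so later changes to the caller's dict do not leak in.
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs)

    def __call__(self, t: Any, x: Any, *args: Any, **kwargs: Any) -> Any:
        # Any extra positional/keyword arguments a solver passes are ignored;
        # only the bound model_kwargs reach the model.
        return self.model(t, x, **self.model_kwargs)
```

Usage would then look like torchdiffeq.odeint(KwargsBoundField(model, low_res=low_res_img), x0, torch.linspace(0, 1, 2)), where low_res is again an assumed keyword name. The design choice is that the conditioning is fixed for a whole trajectory, so binding it once avoids changing the solver's API.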