rsanchezgarc opened this issue 7 months ago
Hello,
This seems to be a problem with Torchdyn. A workaround might be to use torchdiffeq instead, or to write your own custom Euler integration loop. Unfortunately, with the NeurIPS deadline only one month away, I will not have time to look into this issue, especially as it is not really related to TorchCFM but rather to Torchdyn.
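The custom Euler option mentioned above could look roughly like the sketch below. It assumes the model is called as model(t, x, **kwargs), with time first as in the torchcfm UNet models discussed here; `euler_trajectory` is a hypothetical helper, not part of torchcfm, torchdyn or torchdiffeq. Because the loop only uses +, - and *, it should work on plain floats as well as on torch tensors.

```python
from typing import Any, Callable, List, Sequence


def euler_trajectory(
    model: Callable[..., Any],
    x0: Any,
    t_span: Sequence[Any],
    **model_kwargs: Any,
) -> List[Any]:
    """Explicit Euler integration of dx/dt = model(t, x, **model_kwargs).

    Hypothetical helper (not a library API). Returns the state at every
    time point in t_span, starting with x0. Only +, - and * are used, so
    x0 and t_span may be floats or torch tensors.
    """
    xs = [x0]
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        dt = t1 - t0
        # Extra conditioning inputs are forwarded to the model at every step.
        x = x + dt * model(t0, x, **model_kwargs)
        xs.append(x)
    return xs


# Toy check: dx/dt = x from x(0) = 1 with two steps of size 0.5.
print(euler_trajectory(lambda t, x: x, 1.0, [0.0, 0.5, 1.0]))  # [1.0, 1.5, 2.25]
```

With a conditional model this would look something like `torch.stack(euler_trajectory(model, x0, torch.linspace(0, 1, 100), low_res=low_res))`; the `low_res` keyword is an assumption about the model's forward signature, so substitute whatever extra arguments your model actually takes.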
Best, K.
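I would implement this by subclassing SuperResModel so that forward only takes (t, x), i.e.:
from torchcfm.models.unet.unet import SuperResModel

class MySuperResModel(SuperResModel):
    def forward(self, t, x):
        # NeuralODE passes only (t, x); set self.model_kwargs (e.g. the low-res input) beforehand
        return super().forward(t, x, **self.model_kwargs)
I think we do this in the conditional example.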
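A more generic alternative to subclassing is to bind the extra keyword arguments once with `functools.partial`, which turns model(t, x, **kwargs) into a callable with the f(t, x) signature that solvers such as `torchdiffeq.odeint` call. This is a minimal sketch: `bind_model_kwargs` and `toy_conditional_field` are hypothetical names used only for illustration, and the toy field stands in for a real conditional model. If a solver insists on an nn.Module rather than a plain callable (torchdyn may, e.g. for adjoint sensitivities), the subclass approach remains the safer choice.

```python
import functools
from typing import Any, Callable


def bind_model_kwargs(model: Callable[..., Any], **model_kwargs: Any) -> Callable[..., Any]:
    """Adapt model(t, x, **kwargs) to the f(t, x) signature ODE solvers call.

    Hypothetical helper, not part of torchcfm, torchdyn or torchdiffeq.
    """
    return functools.partial(model, **model_kwargs)


def toy_conditional_field(t, x, low_res=0.0):
    # Stand-in for a conditional model: pulls x towards the conditioning value.
    return low_res - x


field = bind_model_kwargs(toy_conditional_field, low_res=3.0)
print(field(0.0, 1.0))  # calls toy_conditional_field(0.0, 1.0, low_res=3.0), prints 2.0
```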
Hi,
I am trying to use your package with
from torchcfm.models.unet.unet import SuperResModel
and with other custom models that take extra kwargs in their forward method, but I think the NeuralODE.trajectory method is not compatible with such models. Could you please add a model_kwargs parameter to NeuralODE.trajectory, NeuralODE.forward, etc.?
Thanks!