DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

__init__() got an unexpected keyword argument 'func' #205

Open ClaudiaShu opened 1 year ago

ClaudiaShu commented 1 year ago

Hi, I am trying to implement a neural SDE using torchdyn's NeuralSDE. My code is:

# drift_func
f = nn.Sequential(...)
# diffusion function
g = nn.Sequential(...)
self.func = NeuralSDE(f, g, solver=args.solver, rtol=args.rtol, atol=args.atol)

But I got this error:

TypeError: __init__() got an unexpected keyword argument 'func'

I checked the source code, and it seems that the arguments NeuralSDE passes to its parent class's __init__ do not match what that constructor accepts:

class NeuralSDE(SDEProblem, pl.LightningModule):
    def __init__(self, drift_func, diffusion_func, noise_type='diagonal', sde_type='ito', order=1,
                 sensitivity='autograd', s_span=torch.linspace(0, 1, 2), solver='srk',
                 atol=1e-4, rtol=1e-4, ds=1e-3, intloss=None):
        super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order,
                         sensitivity=sensitivity, s_span=s_span, solver=solver, atol=atol, rtol=rtol)
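For context on what the two networks are meant to compute: with noise_type='diagonal' and sde_type='ito', a neural SDE models the Itô equation dx = f(t, x) dt + g(t, x) ⊙ dW, where f is drift_func, g is diffusion_func, and g scales an independent Brownian increment for each state component. The following is a minimal pure-Python Euler-Maruyama sketch under that assumption. It is not torchdyn's solver (the default quoted above is solver='srk'), the function name euler_maruyama is hypothetical, and plain lists stand in for tensors.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def euler_maruyama(
    drift: Callable[[float, Vector], Vector],
    diffusion: Callable[[float, Vector], Vector],
    x0: Vector,
    t0: float,
    t1: float,
    dt: float,
    seed: int = 0,
) -> Vector:
    """Hypothetical helper: integrate dx = drift dt + diffusion * dW (Ito, diagonal noise)."""
    rng = random.Random(seed)
    n_steps = int(round((t1 - t0) / dt))
    x = list(x0)
    t = t0
    for _ in range(n_steps):
        # Both functions are evaluated at the left end of the step (Ito convention).
        f = drift(t, x)
        g = diffusion(t, x)
        # Diagonal noise: one independent N(0, dt) increment per state component,
        # scaled element-wise by the diffusion output.
        dw = [rng.gauss(0.0, math.sqrt(dt)) for _ in x]
        x = [xi + fi * dt + gi * dwi for xi, fi, gi, dwi in zip(x, f, g, dw)]
        t += dt
    return x


if __name__ == "__main__":
    decay = lambda t, x: [-xi for xi in x]       # drift f(t, x) = -x
    no_noise = lambda t, x: [0.0 for _ in x]     # g = 0 reduces to explicit Euler
    x_det = euler_maruyama(decay, no_noise, [1.0], 0.0, 1.0, 0.1)
    print(x_det[0], 0.9 ** 10)                   # the two values agree up to rounding
    noisy = lambda t, x: [0.5 for _ in x]        # constant diagonal diffusion
    print(euler_maruyama(decay, noisy, [1.0, 2.0], 0.0, 1.0, 0.01, seed=3))
```

With zero diffusion the update collapses to explicit Euler, which is a quick sanity check; higher-order stochastic schemes refine this same drift-plus-scaled-noise update.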

What arguments should the model be given, or is this a bug that still needs fixing?

Thanks in advance!

ZhangKaly commented 5 months ago

I got the same error! Has it been resolved? I'm curious about the solution.

ricardo-hopker commented 3 months ago

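It looks like SDEProblem is not implemented yet. Because NeuralSDE lists SDEProblem first among its base classes, its super().__init__(func=..., ...) call resolves to SDEProblem.__init__, which accepts no arguments besides self; that is what raises the TypeError. Even without those arguments, the constructor would raise NotImplementedError: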

class SDEProblem(nn.Module):
    def __init__(self):
        "Extension of `ODEProblem` to SDE"
        super().__init__()
        raise NotImplementedError("Hopefully soon...")
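Based on the two snippets quoted in this thread, the failure can be reproduced without torch. The sketch below uses hypothetical stand-in classes (StubSDEProblem for SDEProblem, StubLightningModule for pl.LightningModule, StubNeuralSDE for NeuralSDE) and shows that super() follows the method resolution order, so the keyword argument reaches a constructor that cannot accept it. If the quoted code matches the installed version, no choice of arguments to NeuralSDE avoids the error until SDEProblem is implemented.

```python
class StubSDEProblem:
    """Hypothetical stand-in for torchdyn's unimplemented SDEProblem."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError("Hopefully soon...")


class StubLightningModule:
    """Hypothetical stand-in for pl.LightningModule; never reached below."""

    def __init__(self, *args, **kwargs):
        super().__init__()


class StubNeuralSDE(StubSDEProblem, StubLightningModule):
    def __init__(self, drift_func, diffusion_func):
        # super() follows the MRO, so this calls StubSDEProblem.__init__ first,
        # which takes no arguments besides self.
        super().__init__(func=(drift_func, diffusion_func))


def try_construct():
    """Return the TypeError message raised while building StubNeuralSDE, if any."""
    try:
        StubNeuralSDE(lambda t, x: x, lambda t, x: x)
    except TypeError as err:
        return str(err)
    return None


mro = [cls.__name__ for cls in StubNeuralSDE.__mro__]

if __name__ == "__main__":
    print(mro)  # ['StubNeuralSDE', 'StubSDEProblem', 'StubLightningModule', 'object']
    print(try_construct())  # message ends with "unexpected keyword argument 'func'"
    try:
        StubSDEProblem()  # argument binding succeeds, then the body raises
    except NotImplementedError as err:
        print("NotImplementedError:", err)  # NotImplementedError: Hopefully soon...
```

The TypeError is raised while Python binds arguments, before SDEProblem's body runs, which is why the reported error mentions 'func' rather than NotImplementedError.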