DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

shape of gradient does not match the parameter shape in vector field while using adjoint method #206

Open ljxw88 opened 8 months ago

ljxw88 commented 8 months ago

In torchdyn/numerics/sensitivity.py, function _gather_odefunc_adjoint(), line 71 currently reads:

dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) for el in dμ], dim=-1)

It should be changed to:

param_shapes = [p.shape for p in vf.parameters()]
dμ = torch.cat([el.flatten() if el is not None else torch.zeros(param_shapes[i], device=t.device).flatten() for i, el in enumerate(dμ)], dim=-1)

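Otherwise, any parameter that receives no gradient (el is None) is padded with torch.zeros(1), a single element instead of one zero per parameter entry, so the concatenated gradient dμ no longer matches the shape of the vector field's parameters.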
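A minimal standalone reproduction, offered as a sketch rather than torchdyn's actual code: the VectorField class and the flat_grads_old / flat_grads_fixed helpers are hypothetical, and only the concatenation step from line 71 is mimicked. The vector field registers a parameter that forward() never uses, so torch.autograd.grad(..., allow_unused=True) returns None for it. The fixed helper uses torch.zeros_like(p) instead of torch.zeros(param_shapes[i]).to(t.device); this is a swapped-in variant that also matches each parameter's dtype and device.

```python
import torch
import torch.nn as nn


class VectorField(nn.Module):
    """Hypothetical vector field with one parameter that forward() never uses."""

    def __init__(self, dim: int = 3):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # Registered but unused in forward(), so autograd returns None for it.
        self.unused = nn.Parameter(torch.randn(4, 5))

    def forward(self, t, x):
        return torch.tanh(self.linear(x))


def flat_grads_old(grads):
    # Current behaviour: a missing gradient becomes a single zero.
    return torch.cat([g.flatten() if g is not None else torch.zeros(1) for g in grads], dim=-1)


def flat_grads_fixed(grads, params):
    # Proposed behaviour: a missing gradient becomes zeros shaped like its parameter.
    return torch.cat(
        [g.flatten() if g is not None else torch.zeros_like(p).flatten() for g, p in zip(grads, params)],
        dim=-1,
    )


vf = VectorField()
params = tuple(vf.parameters())
t = torch.tensor(0.0)
x = torch.randn(2, 3, requires_grad=True)
λ = torch.randn(2, 3)  # stands in for the adjoint state

dx = vf(t, x)
# One gradient per input: dλ for x, then one entry (or None) per parameter.
dλ, *dμ = torch.autograd.grad(dx, (x,) + params, -λ, allow_unused=True)

n_params = sum(p.numel() for p in params)
print(n_params)                              # 32 (20 unused + 9 weight + 3 bias)
print(flat_grads_old(dμ).numel())            # 13: the unused parameter collapses to 1 element
print(flat_grads_fixed(dμ, params).numel())  # 32: matches the flattened parameters
```

With the padding sized per parameter, each gradient slice lands at the same offset as its parameter in torch.nn.utils.parameters_to_vector(params), which is what the adjoint state μ assumes.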