thibmonsel / torchdde

Neural network compatible DDEs
https://thibmonsel.github.io/torchdde/
Apache License 2.0

Integrating on different `ts` may result in different `ys` shapes #21

Closed: thibmonsel closed this issue 7 months ago

thibmonsel commented 7 months ago

Minimal working example (MWE):


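import torch
import torch.nn as nn
from torchdde import integrate, RK2

# Vector field of the linear ODE dy/dt = -2 y.
def simple_ode(t, y, args):
    return -2.0 * y

# Observed output shapes:
#   ts = torch.linspace(0, 10, 101) -> ys shape [..., 101, 3]
#   ts = torch.linspace(0, 5, 51)   -> ys shape [..., 52, 3]  (one extra time point)
ts = torch.linspace(0, 5, 51)

# Generate the reference trajectory without tracking gradients.
y0 = torch.rand((10, 3))
with torch.no_grad():
    ys = integrate(simple_ode, RK2(), ts, y0, None, discretize_then_optimize=True)

print("ys.shape", ys.shape)

# SimpleNODE is the reporter's own nn.Module; its definition is not shown in the issue.
model = SimpleNODE()
lossfunc = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0)

# Fit the model so that its trajectory matches ys.
for _ in range(2000):
    opt.zero_grad()
    ret = integrate(model, RK2(), ts, y0, None)
    loss = lossfunc(ret, ys)
    print(ts.shape, ys.shape)
    loss.backward()
    opt.step()
    if loss < 10e-8:
        break

assert torch.allclose(ys, ret, atol=0.01, rtol=0.01)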
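The MWE expects one output point per entry of `ts`, so `ys.shape[-2]` should equal `len(ts)` for any grid. Below is a pure-PyTorch sketch of that shape contract. `rk2_on_grid` is a hypothetical helper, not torchdde's `integrate`, and the issue does not say what caused the extra point, so this is an illustration of the expected behaviour rather than the fix from #22. It uses the midpoint RK2 scheme and takes each step size from consecutive entries of `ts` instead of accumulating a fixed `dt`, so the number of steps is tied directly to `len(ts)`.

```python
import torch


def rk2_on_grid(func, ts, y0, args=None):
    """Hypothetical fixed-step RK2 (midpoint) integrator.

    Takes exactly one step per interval of ts, so the result always
    contains len(ts) time points, laid out as [batch, len(ts), state_dim].
    """
    ys = [y0]
    y = y0
    for i in range(len(ts) - 1):
        t = ts[i]
        h = ts[i + 1] - ts[i]  # step size read from the grid, not accumulated
        k1 = func(t, y, args)
        k2 = func(t + h / 2, y + h / 2 * k1, args)  # midpoint slope
        y = y + h * k2
        ys.append(y)
    # Insert the time axis just before the state axis: [..., len(ts), state_dim].
    return torch.stack(ys, dim=-2)


def simple_ode(t, y, args):
    return -2.0 * y


y0 = torch.rand((10, 3))
for ts in (torch.linspace(0, 10, 101), torch.linspace(0, 5, 51)):
    ys = rk2_on_grid(simple_ode, ts, y0)
    print(ts.shape, ys.shape)
```

With `y0 = torch.rand((10, 3))`, both grids from the issue give `ys` of shape `[10, len(ts), 3]`, and the trajectory stays close to the exact solution `y0 * exp(-2 t)`.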
thibmonsel commented 7 months ago

Fixed with #22.