Closed: thibmonsel closed this issue 7 months ago.
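MWE:

```python
import torch
import torch.nn as nn

# `integrate`, `RK2` and `SimpleNODE` come from the library under test
# (their import path is not shown in the original report).

def simple_ode(t, y, args):
    return -2.0 * y

# With ts = torch.linspace(0, 10, 101), ys has shape [..., 101, 3].
# With ts = torch.linspace(0, 5, 51), ys has shape [..., 52, 3] (one extra time point).
ts = torch.linspace(0, 5, 51)
y0 = torch.rand((10, 3))

# Generate the reference trajectory without tracking gradients.
with torch.no_grad():
    ys = integrate(simple_ode, RK2(), ts, y0, None, discretize_then_optimize=True)
print('ys.shape', ys.shape)

# Fit a neural ODE to the reference trajectory.
model = SimpleNODE()
lossfunc = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0)
for _ in range(2000):
    opt.zero_grad()
    ret = integrate(model, RK2(), ts, y0, None)
    loss = lossfunc(ret, ys)
    print(ts.shape, ys.shape)
    loss.backward()
    opt.step()
    if loss < 10e-8:
        break

assert torch.allclose(ys, ret, atol=0.01, rtol=0.01)
```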
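The MWE expects `integrate` to return one state per entry of `ts`, so `ts = torch.linspace(0, 5, 51)` should give `ys` of shape `[..., 51, 3]` rather than `[..., 52, 3]`. As a hedged illustration of that shape contract only (it is not the library's code and not necessarily what #22 changed), here is a minimal plain-PyTorch fixed-grid integrator. It assumes Heun's method as the second-order Runge-Kutta scheme, and `rk2_on_grid` is a hypothetical name. Because it steps between consecutive entries of `ts`, it produces exactly `len(ts)` states by construction.

```python
import torch

def rk2_on_grid(f, ts, y0, args=None):
    """Hypothetical fixed-grid RK2 (Heun's method) integrator.

    Returns a tensor of shape [..., len(ts), state_dim]: one state per time in ts.
    """
    ys = [y0]  # state at ts[0]
    y = y0
    for i in range(len(ts) - 1):
        # Step exactly from ts[i] to ts[i + 1], so the output grid matches ts.
        t0, t1 = ts[i], ts[i + 1]
        dt = t1 - t0
        k1 = f(t0, y, args)                # slope at the start of the step
        k2 = f(t0 + dt, y + dt * k1, args) # slope at the Euler-predicted end point
        y = y + 0.5 * dt * (k1 + k2)       # average the two slopes (Heun)
        ys.append(y)
    # Stack along a new time axis just before the state dimension.
    return torch.stack(ys, dim=-2)

def simple_ode(t, y, args):
    return -2.0 * y

ts = torch.linspace(0, 5, 51)
y0 = torch.rand((10, 3))
ys = rk2_on_grid(simple_ode, ts, y0)
print(ys.shape)  # torch.Size([10, 51, 3])
```

Since `dy/dt = -2y` has the exact solution `y0 * exp(-2t)`, the output can also be checked against it on the same grid.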
Fixed with #22.