patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Save fix for `t0==t1` #494

Open dkweiss31 opened 3 months ago

dkweiss31 commented 3 months ago

Addresses the edge case raised in https://github.com/patrick-kidger/diffrax/issues/488, where t0 == t1 and saveat.ts is not None. Additionally, if saveat.t0 is True then those values were not updated either, which this PR should also address. I've included a test for this case as well.

WRT the implementation: a loop is not very nice, since everything could in principle be done in parallel, but the version below did not work for the ts part due to dynamic slicing errors. Let me know if there is a nicer workaround I could try :)

if subsaveat.ts is not None:
    _ts = subsaveat.ts
    save_idx = save_state.save_index
    ts = save_state.ts.at[save_idx: save_idx + len(_ts)].set(_ts)
    _ys = [subsaveat.fn(t1, yfinal, args)] * len(_ts)
    ys = save_state.ys.at[save_idx: save_idx + len(_ts)].set(_ys)
    save_state = SaveState(
        saveat_ts_index=save_idx + len(_ts),
        ts=ts,
        ys=ys,
        save_index=save_idx + len(_ts),
    )

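
For readers unfamiliar with the "dynamic slicing errors" mentioned above, here is a minimal sketch in plain JAX. It is not the code from this PR, and the names (fill_with_slice, fill_with_dynamic_update, buffer, start, values) are hypothetical. It assumes the write position is a traced value, as it would be inside a jitted solve, while the number of values to write is static. In that setting, NumPy-style slice syntax is rejected, whereas jax.lax.dynamic_update_slice accepts a traced start index.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def fill_with_slice(buffer, start, values):
    # The slice bounds depend on the traced `start`; NumPy-style slicing
    # needs static bounds, so tracing this function raises an error.
    return buffer.at[start : start + values.shape[0]].set(values)


@jax.jit
def fill_with_dynamic_update(buffer, start, values):
    # Only the start position is dynamic; the update size is the static
    # shape of `values`, which is what dynamic_update_slice expects.
    return lax.dynamic_update_slice(buffer, values, (start,))


# Hypothetical stand-ins for a pre-allocated save buffer and the values to store.
buffer = jnp.full((6,), jnp.inf)
values = jnp.full((3,), 0.5)
start = jnp.asarray(1)

slice_error = None
try:
    fill_with_slice(buffer, start, values)
except Exception as err:  # the exact exception type may vary between JAX versions
    slice_error = type(err).__name__

filled = fill_with_dynamic_update(buffer, start, values)
print(slice_error)
print(filled)
```

Two caveats apply. dynamic_update_slice clamps the start index so that the update always fits inside the buffer, which may or may not be the desired behaviour near the end of the buffer. Whether this approach fits the SaveState used in this PR is an open question: the saved ys there may be a PyTree rather than a single array, in which case the update would have to be mapped over its leaves.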
dkweiss31 commented 3 months ago

To address some failing tests re reverse-mode differentiation I converted it to a while_loop, but I'm still seeing some failed tests. Converting this to a draft for now.

dkweiss31 commented 1 week ago

@patrick-kidger sorry for the long delay! I think the PR is ready for review now. All tests pass except for one of the tqdm progress bar tests involving jit: I'm not at all sure what is going on there?

Additionally, I wanted to draw your attention to what I wrote on line 773:

def _save_ts_impl(ts, fn, _save_state):
    def _cond_fun(__save_state):
        return __save_state.saveat_ts_index < len(_save_state.ts)

where I had to use _save_state.ts instead of ts in the conditional check, because saveat_ts_index can already be 1 if _save_state.t0==True. If I used ts instead, the last entry would not get updated. This doesn't exactly mirror what's happening on lines 421-427, so I just wanted to briefly mention it.
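
The off-by-one described above can be reproduced with a toy model. This is only a sketch under assumptions: it is not the PR's _save_ts_impl, and fill_remaining, requested_ts and buffer are hypothetical names. It assumes an output buffer with one slot already written for t0, followed by one slot per requested time, and a loop counter that starts at 1.

```python
import jax.numpy as jnp
from jax import lax


def fill_remaining(buffer, start_index, value, bound):
    # Write `value` into buffer[start_index], buffer[start_index + 1], ...
    # for as long as the running index stays below `bound`.
    def cond_fun(carry):
        index, _ = carry
        return index < bound

    def body_fun(carry):
        index, buf = carry
        return index + 1, buf.at[index].set(value)

    _, filled = lax.while_loop(cond_fun, body_fun, (start_index, buffer))
    return filled


# Hypothetical layout: three requested save times, plus slot 0 already
# holding the value saved at t0.
requested_ts = jnp.array([0.0, 0.0, 0.0])
buffer = jnp.full((1 + requested_ts.shape[0],), jnp.inf).at[0].set(0.0)

# Bounding the loop by the number of requested times stops one slot early ...
short = fill_remaining(buffer, 1, 2.0, requested_ts.shape[0])
# ... whereas bounding it by the buffer length fills every remaining slot.
full = fill_remaining(buffer, 1, 2.0, buffer.shape[0])

print(short)  # the last slot still holds the inf placeholder
print(full)   # every slot after the first holds 2.0
```

In this toy model the counter starts at 1 because one slot is already occupied, so the loop bound has to be the length of the full buffer rather than the number of requested times. That matches the behaviour described in the comment, though the real SaveState bookkeeping may differ in detail.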

patrick-kidger commented 1 week ago

Awesome! Can you rebase on top of dev (+make that the PR target branch) and I'll do a review? :) (I think this should also fix the tqdm-jit test.)

dkweiss31 commented 1 week ago

Ok, done I think!