dkweiss31 opened 3 months ago
To address some failing tests related to reverse-mode differentiation, I converted it to a `while_loop`, but some tests still fail. Converting this to a draft for now.
@patrick-kidger sorry for the long delay! I think the PR is ready for review now. All tests pass except for one of the tqdm progress bar tests involving `jit`; I'm not at all sure what is going on there.
Additionally, I wanted to draw your attention to the code I wrote at line 773:

```python
def _save_ts_impl(ts, fn, _save_state):
    def _cond_fun(__save_state):
        # Bound by the full save buffer rather than `ts`:
        # saveat_ts_index may already be 1 when t0 is saved.
        return __save_state.saveat_ts_index < len(_save_state.ts)
    ...  # rest of the function omitted
```

Here I had to use `_save_state.ts` instead of `ts` in the conditional check, because `saveat_ts_index` can already be 1 if `_save_state.t0==True`. If I used `ts`, the last entry would not get updated. This doesn't exactly mirror what happens on lines 421-427, so I just wanted to briefly mention it.
Awesome! Can you rebase on top of `dev` (and make that the PR target branch), and I'll do a review? :) (I think this should also fix the tqdm-jit test.)
OK, done, I think!
Addresses the edge case raised in https://github.com/patrick-kidger/diffrax/issues/488, when `t0 == t1` and `saveat.ts` is not `None`. Additionally, if `saveat.t0` is `True`, then those values were not updated either; this PR should address that as well. I've also included a test for this case.

Regarding the implementation: a loop is not ideal, since everything could in principle be done in parallel, but a vectorized version did not work for the `ts` part due to dynamic slicing errors. Let me know if there is a nicer workaround I could try :)
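The vectorized attempt itself is not included in this thread, so the following is only a hedged sketch of one common way around dynamic slicing errors in JAX: keep every array at a fixed shape and select entries with a mask instead of slicing at a traced offset. `fill_from` and its arguments are hypothetical, and the sketch assumes the saved states sit in a fixed-size buffer whose first `start` rows are already written.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fill_from(buf_ys, start, y1):
    """Hypothetical sketch: set rows `start:` of `buf_ys` to `y1` in parallel.

    Slicing as `buf_ys.at[start:]` typically fails under `jit` when `start`
    is traced; comparing row positions against `start` keeps shapes static.
    """
    positions = jnp.arange(buf_ys.shape[0])
    mask = positions >= start
    # Add trailing axes so the mask broadcasts over the state dimensions.
    mask = mask.reshape(mask.shape + (1,) * (buf_ys.ndim - 1))
    return jnp.where(mask, y1, buf_ys)


buf = jnp.zeros((4, 2))  # 4 save slots for a 2-dimensional state
out = fill_from(buf, jnp.asarray(2), jnp.ones(2))
print(out)  # rows 0-1 stay zero, rows 2-3 become ones
```

Because `start` is an ordinary traced argument here, the same compiled function can be reused whether or not a leading entry (such as a saved `t0`) is already present.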