lxuechen closed this 3 years ago.
Also, `tests/__init__.pyc` shouldn't be included in the commit.
Overall, things are looking so much cleaner now. These changes make me very happy.
Just spotted: both `sdeint` and its adjoint have the docstring for `names` still specifying the possibility of prior drift.
Done.
This PR looks good to me. I've not hit the merge button as I assume the plan is to merge #40 first, so that the diagnostics can be run.
Thanks for the reviews @patrick-kidger! All rates preserved after running the diagnostics. Merging now.
Contents in this PR:

- `y0` and tuple-based `sde` for `sdeint`.
- `ForwardSDE` that has the correct versions of `f`, `g_prod`, `gdg_prod`, and `dg_ga_jvp` (per our email discussion).
- `AdjointSDE` take in `ForwardSDE` and use the latter's `g_prod`.
- `g_prod` of the former.
- `AdjointSDE` take in (and output) flattened tensors and perform flattening/unflattening inside the registered functions.
- `adjoint.py`
- `torchdiffeq`, where the `_TupleFunc` construction is in `_check_inputs`.
- `sdeint.integrate` only supports tensor-based `y0` and `sde`, and we should enforce this usage for now.
- `sdeint_adjoint`.
- `diagnostics`.
- `logqp` support.
- `sdeint` inconsistency. Hence, resolving #27.

The rate diagnostics files run, but I haven't been able to obtain rates due to #36. I need to run these when this is fixed. All tests pass.
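To illustrate the flattening/unflattening idea in the PR description (tuple states flattened once, with a `torchdiffeq`-style `_TupleFunc` wrapper built during input checking), here is a minimal pure-Python sketch. It is not torchsde's or torchdiffeq's actual code: plain lists of floats stand in for tensors, and `flatten`, `unflatten`, `TupleFunc`, and `check_inputs` are hypothetical names chosen for this illustration.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Plain lists of floats stand in for 1-D tensors in this sketch.
Vector = List[float]
TupleState = Tuple[Vector, ...]


def flatten(parts: Sequence[Vector]) -> Tuple[Vector, List[int]]:
    """Concatenate state components into one flat vector and record their sizes."""
    sizes = [len(p) for p in parts]
    flat = [x for p in parts for x in p]
    return flat, sizes


def unflatten(flat: Sequence[float], sizes: Sequence[int]) -> TupleState:
    """Split a flat vector back into components of the recorded sizes."""
    if len(flat) != sum(sizes):
        raise ValueError("flat vector length does not match the recorded sizes")
    parts: List[Vector] = []
    start = 0
    for n in sizes:
        parts.append(list(flat[start:start + n]))
        start += n
    return tuple(parts)


class TupleFunc:
    """Hypothetical wrapper: exposes a function defined on tuple states
    as a function on flat vectors, so the solver only ever sees one vector."""

    def __init__(self, base_func: Callable[[float, TupleState], Sequence[Vector]],
                 sizes: Sequence[int]) -> None:
        self.base_func = base_func
        self.sizes = list(sizes)

    def __call__(self, t: float, y_flat: Sequence[float]) -> Vector:
        y = unflatten(y_flat, self.sizes)  # flat vector -> tuple of components
        out = self.base_func(t, y)         # user code works on the tuple
        flat_out, _ = flatten(out)         # tuple -> flat vector for the solver
        return flat_out


def check_inputs(func: Callable, y0) -> Tuple[Callable, Vector, Optional[List[int]]]:
    """Hypothetical input check: if y0 is a tuple, flatten it once up front
    and wrap func; otherwise pass both through unchanged."""
    if isinstance(y0, tuple):
        y0_flat, sizes = flatten(y0)
        return TupleFunc(func, sizes), y0_flat, sizes
    return func, list(y0), None


if __name__ == "__main__":
    def drift(t: float, y: TupleState) -> TupleState:
        a, b = y
        return [-x for x in a], [2.0 * x for x in b]

    f, y0_flat, sizes = check_inputs(drift, ([1.0, 2.0], [3.0]))
    print(y0_flat)                            # [1.0, 2.0, 3.0]
    print(f(0.0, y0_flat))                    # [-1.0, -2.0, 6.0]
    print(unflatten(f(0.0, y0_flat), sizes))  # ([-1.0, -2.0], [6.0])
```

Flattening once at the entry point keeps everything downstream (the solver, or an adjoint's registered functions) working on a single vector, at the cost of an unflatten/flatten pair on every function call.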
Overall, this PR should also make adding `dg_ga_jvp` for the adjoint easier.
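The `ForwardSDE` item in the PR description mentions `g_prod`, a product of the diffusion g(t, y) with a vector. Assuming the usual convention that `g_prod(t, y, v)` returns g(t, y) applied to a noise vector v (the thread itself does not spell this out), the two common noise layouts reduce to the pure-Python sketch below. Lists stand in for tensors, the batch dimension is omitted, and `g_prod_diagonal` / `g_prod_general` are hypothetical names, not torchsde functions.

```python
from typing import List, Sequence

Vector = List[float]


def g_prod_diagonal(g: Sequence[float], v: Sequence[float]) -> Vector:
    """Diagonal noise: g(t, y) is stored as a length-d vector (its diagonal),
    so applying it to a noise vector v is an elementwise product."""
    if len(g) != len(v):
        raise ValueError("diagonal noise needs len(g) == len(v)")
    return [gi * vi for gi, vi in zip(g, v)]


def g_prod_general(g: Sequence[Sequence[float]], v: Sequence[float]) -> Vector:
    """General noise: g(t, y) is a d x m matrix and v has length m,
    so the product is an ordinary matrix-vector product of length d."""
    for row in g:
        if len(row) != len(v):
            raise ValueError("each row of g must have len(v) entries")
    return [sum(gij * vj for gij, vj in zip(row, v)) for row in g]


if __name__ == "__main__":
    print(g_prod_diagonal([1.0, 2.0], [3.0, 4.0]))                            # [3.0, 8.0]
    print(g_prod_general([[1.0, 2.0], [3.0, 4.0], [0.5, 0.0]], [1.0, 1.0]))   # [3.0, 7.0, 0.5]
```

Diagonal noise stores only the diagonal of g, which is why its product is elementwise rather than a full matrix-vector product.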