google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0
1.52k stars, 195 forks

Added double adjoint #49

Closed by patrick-kidger 3 years ago

patrick-kidger commented 3 years ago

Just creating a draft PR to let you know I'm working on double-backward through the adjoint.

In informal tests it seems to work as expected. The only thing left to do is write some formal tests. I'll put them in test_adjoint.py and switch that file over to pytest while I'm at it.
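Double backward through an adjoint is only possible if the backward pass is itself differentiable. The sketch below shows that property in miniature with a toy custom autograd Function. `Cube` is a hypothetical name and is not part of torchsde; its backward uses ordinary torch ops, so autograd can record them when `create_graph=True`.

```python
import torch


class Cube(torch.autograd.Function):
    """Toy y = x**3 whose backward is written with differentiable torch ops."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Ordinary torch ops: with create_graph=True autograd records them,
        # which is what makes a second backward pass possible.
        return grad_output * 3 * x ** 2


x = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
y = Cube.apply(x)
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)  # 6 * x
print(dy_dx.item(), d2y_dx2.item())  # 12.0 12.0
```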

lxuechen commented 3 years ago

Thanks for putting up a draft!

Some general comments

> In informal tests it seems to work as expected.

What are the informal tests? Are they finite-difference tests, or did you use torch.autograd.gradgradcheck directly? Do the old tests still pass? If you could also push the companion tests, I'd be able to help more by running and tweaking the code.

I hope asking all of these questions at once isn't too annoying; I just want a better understanding of what's happening here. We also don't have to address everything above at once, but it's good to know what works and what doesn't so that we can leave meaningful TODOs.

patrick-kidger commented 3 years ago

But that's a detail that's beyond AdjointSDE's scope. AdjointSDE requires that certain conditions are satisfied, so it should enforce them. Practically: if that detail ever changes down the line then I'd prefer to know about it without things silently going wrong. The logic here can get quite hairy when you start getting into adjoints-of-adjoints.

Old single-adjoint tests still pass.

lxuechen commented 3 years ago

Including t and v: AFAIK both t and v should never require gradients in the current set-up.

Given current usage, I can't come up with a case where t and v are not leaf variables, so to reduce redundancy I would prefer not to have _get_state check them. The idea is that this PR adds the requires_grad checks for ts up front, and v should always be generated by the Brownian motion. If we're really unsure, I would leave some comments here documenting the situation.
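An up-front check on ts matters for the leaf assumption because indexing a grad-requiring tensor yields a non-leaf. The sketch below illustrates this. `check_ts` is a hypothetical helper, not torchsde's actual validation code.

```python
import torch


def check_ts(ts):
    """Hypothetical entry-point check: reject time grids that require gradients."""
    if ts.requires_grad:
        raise ValueError("ts must not require gradients")
    return ts


# A plain time grid is a leaf that does not require grad, so it passes.
ts = check_ts(torch.linspace(0.0, 1.0, 5))

# If ts did require grad, every t sliced from it would be a non-leaf.
ts_grad = torch.linspace(0.0, 1.0, 5, requires_grad=True)
t = ts_grad[1]
print(ts_grad.is_leaf, t.is_leaf)  # True False

try:
    check_ts(ts_grad)
except ValueError as err:
    print("rejected:", err)
```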

> implicit contexts: this isn't true in the double-adjoint case, when f_uncorrected of the double adjoint sets enable_grad and then calls f_uncorrected of the first adjoint.

You're right, thanks for explaining.
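The point about implicit contexts is that grad mode is dynamic state. An inner torch.enable_grad() overrides an outer torch.no_grad(), so a nested function cannot assume which mode it runs under. The sketch below uses hypothetical stand-ins for the two f_uncorrected functions; the names and bodies are illustrative and are not torchsde's code.

```python
import torch


def first_adjoint_f(y):
    """Hypothetical stand-in for the first adjoint's f_uncorrected."""
    # Report the current grad mode instead of assuming it is off.
    return y * 2, torch.is_grad_enabled()


def double_adjoint_f(y):
    """Hypothetical stand-in for the double adjoint's f_uncorrected."""
    with torch.enable_grad():
        return first_adjoint_f(y)


y = torch.ones(3, requires_grad=True)
with torch.no_grad():
    out, grad_enabled = double_adjoint_f(y)
print(grad_enabled, out.requires_grad)  # True True
```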

> Naive backprop: yeah, it should probably be added to the docstring for sdeint_adjoint. I'll do that.

Thanks for fixing the documentation!

Additionally, could we do a numerical test? Ideally, just replacing the torch.autograd.gradcheck here with torch.autograd.gradgradcheck (and slightly modifying the surrounding code) should tell us something.
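The sketch below shows the proposed numerical test. It assumes a plain fixed-step Euler-Maruyama loop as a stand-in for the solver, with drift mu * y and diffusion sigma * y, both arbitrary choices. It backprops directly through the loop rather than exercising torchsde's adjoint; the point is only how gradcheck and gradgradcheck are applied side by side.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)
N_STEPS, DT = 20, 0.05
# Brownian increments are sampled once and reused, so every call follows the
# same path and the function being checked is deterministic.
DW = torch.randn(N_STEPS, dtype=torch.float64) * DT ** 0.5


def euler_maruyama(y0, mu, sigma):
    """Euler-Maruyama solve of dY = mu * Y dt + sigma * Y dW along the fixed path DW."""
    y = y0
    for dw in DW:
        y = y + mu * y * DT + sigma * y * dw
    return y


# Both checks compare autograd against finite differences and need
# double-precision inputs that require grad.
inputs = tuple(
    torch.tensor([value], dtype=torch.float64, requires_grad=True)
    for value in (1.0, 0.3, 0.2)
)
first_order_ok = gradcheck(euler_maruyama, inputs)
second_order_ok = gradgradcheck(euler_maruyama, inputs)
print(first_order_ok, second_order_ok)  # True True
```

Applying the same two checks to an adjoint solve would also need a fixed Brownian path, so that the repeated calls inside the checks see identical noise.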

As a side note, I'll be quite busy on Tuesday and Wednesday this week, but if you're willing to wait a little, I can do the tests on Thursday and Friday. In any case, I think the tests have to get done. You're obviously a pretty good coder and software engineer, but I wouldn't consider the code complete or fully trustworthy until we have those tests.

patrick-kidger commented 3 years ago

I agree that every current case involves t and v being leaf variables already. (Moreover non-gradient-requiring variables.) I'm quite strongly in favour of maintaining the checks though, for what I think is the same reason that you brought up the discussion on requires_grad: we might not be doing it now, but it's an implicit assumption that will silently do the wrong thing if we ever change that assumption in the future.

Regarding tests - I feel like you think I'm a bit cavalier about tests! Don't worry, I appreciate their importance for any software project; I just tend to put them in a bit later than you. For this PR specifically, it sounds like you're offering to write the tests? If you get time and are happy to do it, that would be great.

lxuechen commented 3 years ago

I appreciate the detailed thoughts, though I still believe there's a concrete difference between the issue I mentioned and the argument about t and v. I think we currently do have a use case for backprop through the solver, and the issue I mentioned could affect that behavior. On the other hand, I can't come up with a case where gradients wrt t and v need to be taken.

patrick-kidger commented 3 years ago

Hmm I don't quite follow. Gradients wrt t and v are a separate issue to this breaking Milstein. (If that's what you're saying then I agree!) I agree that we never need gradients wrt t or v, I'm just quite assert-happy as a way of making sure things aren't silently misbehaving. I agree that this PR introduces an issue wrt Milstein (whilst fixing the issue wrt leaking graphs).
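A change to gradient tracking can plausibly affect Milstein because its correction term needs dg/dy. The sketch below assumes that derivative is taken with autograd; derivative-free Milstein variants avoid this. It is a scalar/diagonal-noise Itô toy, not torchsde's implementation, and not necessarily the exact mechanism of the issue discussed here.

```python
import torch


def milstein_step(f, g, y, dt, dw):
    """One Ito Milstein step for diagonal noise, with dg/dy taken by autograd.

    y must require grad; otherwise torch.autograd.grad raises a RuntimeError.
    """
    g_val = g(y)
    # Diagonal noise: g_i depends only on y_i, so the gradient of the sum
    # gives the elementwise derivatives dg_i/dy_i.
    (dg_dy,) = torch.autograd.grad(g_val.sum(), y, create_graph=True)
    return y + f(y) * dt + g_val * dw + 0.5 * g_val * dg_dy * (dw ** 2 - dt)


MU, SIGMA = 0.1, 0.5


def drift(y):
    return MU * y


def diffusion(y):
    return SIGMA * y


y = torch.tensor([1.0], dtype=torch.float64, requires_grad=True)
y_next = milstein_step(drift, diffusion, y, dt=0.01, dw=0.2)
print(y_next.item())  # about 1.10475

# The same step on a state that autograd is not tracking fails.
try:
    milstein_step(drift, diffusion, torch.tensor([1.0], dtype=torch.float64), dt=0.01, dw=0.2)
except RuntimeError as err:
    print("Milstein correction failed:", err)
```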

Reflecting on the main (gradient) issue, I am actually inclined to just set requires_grad = torch.is_grad_enabled() and leave it at that. This does leak graphs, but I don't think the overhead should be too large. I think the alternate fix, or manually checking the graph, is probably a bit too magic.
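The sketch below shows the requires_grad = torch.is_grad_enabled() approach. It assumes a hypothetical fresh_state helper that detaches the state before the drift and diffusion are evaluated. Under no_grad nothing is recorded. Under grad mode a graph is built whether or not it is later used, which is the graph-leak overhead mentioned above.

```python
import torch


def fresh_state(y):
    """Hypothetical helper: detach y, then track gradients only if grad mode is on."""
    return y.detach().requires_grad_(torch.is_grad_enabled())


def diffusion(y):
    return 0.5 * y


y = torch.randn(3)

with torch.no_grad():
    s_off = fresh_state(y)
    out_off = diffusion(s_off)
print(s_off.requires_grad, out_off.grad_fn is None)  # False True

s_on = fresh_state(y)
out_on = diffusion(s_on)
# Under grad mode a graph is recorded even if nothing ever backprops
# through it.
print(s_on.requires_grad, out_on.grad_fn is not None)  # True True
```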

lxuechen commented 3 years ago

> Hmm I don't quite follow. Gradients wrt t and v are a separate issue to this breaking Milstein. (If that's what you're saying then I agree!)

Yes, I'm trying to argue that these are separate situations.

> I agree that we never need gradients wrt t or v, I'm just quite assert-happy as a way of making sure things aren't silently misbehaving.

I still don't think it's necessary to check t or v, as I'm unable to come up with any likely scenario where these variables would require grad. Given that this check also adds complexity to the code and a tiny overhead, I'm not sure of the advantage.

> I agree that this PR introduces an issue wrt Milstein (whilst fixing the issue wrt leaking graphs).

> Reflecting on the main (gradient) issue, I am actually inclined to just set requires_grad = torch.is_grad_enabled() and leave it at that. This does leak graphs, but I don't think the overhead should be too large. I think the alternate fix, or manually checking the graph, is probably a bit too magic.

I agree with this fix, i.e. just leaving requires_grad = torch.is_grad_enabled().

Side note: I think this PR might be ready to be converted from a draft to a regular PR for review, modulo the t/v grad-check discussion.

patrick-kidger commented 3 years ago

Changed requires_grad.

Regarding the asserts: as I understand it, you want to remove the assert t.is_leaf and assert v.is_leaf lines from adjoint_sde.py? I think it's important that AdjointSDE (or indeed any other library component) is self-contained; explicitly checked assumptions are better than implicit assumptions, especially in cases like this where a failure of that assumption will fail silently rather than fail loudly.

I can't think of a use case for t or v needing gradients either - but that could change not just because of some unforeseen feature, but because of a bug. (And the cost/complexity of two asserts is very low.)
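The sketch below illustrates the explicit-assertion style being argued for, using a hypothetical check_t_and_v helper; the checks discussed above are the assert t.is_leaf / assert v.is_leaf pair in adjoint_sde.py. Because is_leaf alone accepts a user-created tensor with requires_grad=True, the sketch also checks the gradient flag.

```python
import torch


def check_t_and_v(t, v):
    """Hypothetical guard in the spirit of `assert t.is_leaf` / `assert v.is_leaf`."""
    # is_leaf alone would still accept a user-created tensor with
    # requires_grad=True, so the gradient flag is checked as well.
    assert t.is_leaf and not t.requires_grad, "t must be a leaf that does not require grad"
    assert v.is_leaf and not v.requires_grad, "v must be a leaf that does not require grad"


t = torch.tensor(0.5)
v = torch.randn(3)  # e.g. noise sampled from a Brownian motion
check_t_and_v(t, v)  # passes

scale = torch.tensor(2.0, requires_grad=True)
try:
    check_t_and_v(t * scale, v)  # simulates a bug that makes t depend on a parameter
except AssertionError as err:
    print("caught:", err)
```

The failure is loud (an AssertionError) instead of a silently wrong gradient, which is the trade-off described above.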

lxuechen commented 3 years ago

Ok, let's leave the checks as they are for now. Happy to have this merged and thanks again for the great work in getting the adjoint-adjoints working!