google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Add support for Stratonovich adjoint #21

Closed lxuechen closed 3 years ago

lxuechen commented 3 years ago

Opening this up is mainly to let you know this is in progress.

The remaining stuff:

Update: It's all done, and all tests now pass. I'd appreciate comments on getting the code into better shape: although I tried to think carefully through most of it, some parts were done hastily. @patrick-kidger

This PR is really long, and I'd prefer not to block or be blocked by others' work, so I'm leaving gdg_jvp for the adjoint to another PR.

A couple of random things that playing around with the new code has made me think about:

Update: Additional caveats:

patrick-kidger commented 3 years ago

Excellent, thanks. If this is still in progress, ping me when you want me to go over it.

lxuechen commented 3 years ago

A subtle issue I found was that ft.reduce(operator.eq, l) doesn't actually check if all the entries in the list are equal. A typical example would be

ft.reduce(operator.eq, [0, 0, 0])

The reason is that the first eq comparison returns True, and the next comparison then checks True against 0, which is False.

I fixed this for BInterval in this PR.
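A minimal sketch of the pitfall and one order-independent alternative. `all_equal` is a hypothetical helper name, not necessarily what the PR uses; it compares every element against the first instead of chaining comparison results.

```python
import functools as ft
import operator

# reduce feeds the bool result of each comparison into the next comparison.
assert ft.reduce(operator.eq, [0, 0, 0]) is False  # (0 == 0) -> True, then True == 0 -> False
assert ft.reduce(operator.eq, [1, 1, 1]) is True   # only "works" because True == 1


def all_equal(items):
    """Return True if every element compares equal to the first one (True for empty input)."""
    items = list(items)
    return all(x == items[0] for x in items[1:])


assert all_equal([0, 0, 0])
assert all_equal([2, 2, 2])
assert not all_equal([0, 0, 1])
```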

patrick-kidger commented 3 years ago

I'll start looking at the code now. Regarding your further bullet points:

lxuechen commented 3 years ago

Sorry, I didn't realize that adding a mention in an edit doesn't send a new notification. I'll start a new message next time.

No strong feelings on whether the interface for scalar noise is (..., d) or (..., d, 1), but I think we should pick just one and not try to support both. I'd note that the latter is what the code is currently set up for and what's currently in the README.

This seems good. I need to change the step function a bit then, as I noticed that it previously assumed g outputs (..., d) and bm outputs (..., 1).
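To make the (..., d, 1) convention concrete, here is a small PyTorch sketch, assuming the scalar-noise diffusion has shape (batch, d, 1) and the Brownian increment has shape (batch, 1). The variable names are illustrative, not torchsde internals.

```python
import torch

torch.manual_seed(0)
batch_size, d = 4, 3
g = torch.randn(batch_size, d, 1)  # scalar-noise diffusion: (..., d, 1)
dw = torch.randn(batch_size, 1)    # Brownian increment: (..., 1)

# Contract the trailing noise dimension: (..., d, 1) @ (..., 1, 1) -> (..., d, 1) -> (..., d)
g_dw = (g @ dw.unsqueeze(-1)).squeeze(-1)

# With a single noise channel this is just g's only column scaled by dw.
assert torch.allclose(g_dw, g.squeeze(-1) * dw)
```

Keeping the explicit noise dimension means the same matrix-vector contraction serves scalar and general noise alike.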

Which type checks are slowing things down?

The control-flow in check_contract seems to be costly for very small problems. At some point, I think we could have something that recognizes an environment variable that disables certain checks. What do you think about this? Certainly not on the top of the priority list though.

patrick-kidger commented 3 years ago

Phew, what a lot of comments. Have fun with those...

I agree, not a priority right now. I'd say probably not an environment variable; if nothing else, no other part of PyTorch does this. FWIW, if it's a small problem that gets solved blazingly fast anyway, then the only time we'd be concerned about this is if the user is solving lots of small problems sequentially, without batching, which is sufficiently niche that I'm not too concerned. If we really want this, then I'd suggest an additional unsafe argument that disables the checks.
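A sketch of what such an opt-out argument could look like. `check_contract` and `sdeint_sketch` below are hypothetical, heavily simplified stand-ins (the real validation does far more), and `solve` is a placeholder for the actual integration.

```python
def check_contract(y0, ts):
    """Hypothetical, heavily simplified input validation."""
    if len(ts) < 2:
        raise ValueError("ts must contain at least two time points.")
    if any(t_next <= t_prev for t_prev, t_next in zip(ts, ts[1:])):
        raise ValueError("ts must be strictly increasing.")


def sdeint_sketch(solve, y0, ts, unsafe=False):
    """Validate inputs unless unsafe=True, then delegate to the solver."""
    if not unsafe:
        check_contract(y0, ts)
    return solve(y0, ts)


# Usage: a dummy "solver" that just returns y0 at every requested time.
constant_solve = lambda y0, ts: [y0 for _ in ts]
assert sdeint_sketch(constant_solve, 1.0, [0.0, 1.0]) == [1.0, 1.0]
# With unsafe=True, invalid ts slips through unchecked; that is the trade-off.
assert sdeint_sketch(constant_solve, 1.0, [1.0, 0.0], unsafe=True) == [1.0, 1.0]
```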

lxuechen commented 3 years ago

More broadly, I think we should be able to remove return_U and return_A in __call__.

Pros: makes the code cleaner and reduces Python overhead. Cons: makes the bm less flexible, as its __call__ would be fixed when the object is created.

patrick-kidger commented 3 years ago

(Copy from above since you commented on it down here, for clarity) I agree the return_U stuff is a bit of a hassle. The reason I arranged it like this was so:

I'm not sure what else would satisfy those two conditions.

EDIT: Thinking about it - we could remove return_U and return_A. Then in the diagnostic BM, set levy_area_approximation='space-time' and just wrap it as e.g. sdeint(..., bm=lambda ta, tb: bm(ta, tb)[0]) for the non-SRK solvers. Sound good?
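A sketch of the proposed wrapping, under stated assumptions: `ToySpaceTimeBM` is a hypothetical stand-in (not torchsde's BrownianInterval) that returns an (increment, space-time Levy area) pair, sampling the rescaled space-time Levy area as H ~ N(0, dt/12) independently of the increment. Unlike a real Brownian motion, it only caches exact repeated queries and does not keep overlapping intervals consistent.

```python
import math
import random


class ToySpaceTimeBM:
    """Hypothetical scalar BM returning (W, H) for an interval query [ta, tb]."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self._cache = {}

    def __call__(self, ta, tb):
        key = (ta, tb)
        if key not in self._cache:
            dt = tb - ta
            w = self._rng.gauss(0.0, math.sqrt(dt))       # increment W ~ N(0, dt)
            h = self._rng.gauss(0.0, math.sqrt(dt / 12))  # space-time Levy area H ~ N(0, dt/12)
            self._cache[key] = (w, h)
        return self._cache[key]


bm = ToySpaceTimeBM(seed=0)
# What a non-SRK solver would be handed: just the increment.
increment_only = lambda ta, tb: bm(ta, tb)[0]
```

The wrapper lives entirely on the caller's side, so the BM object itself keeps a single fixed return type.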

patrick-kidger commented 3 years ago

Thinking about strategy, something like this? (In order)

I can adjust the return_U/A stuff, tear out logqp and try and make some free time to do tuple-to-tensor. Meanwhile you could look at BTree levy area and/or BPath optimisation? (Or whatever else you think is most important and won't clash.)

Removing logqp and doing the tuple-to-tensor rewrite should simplify the code dramatically, which will make everything going forward ten times easier.

patrick-kidger commented 3 years ago

Leaving another comment to keep points organised: I'm guessing the reason for implementing the adjoint SDE with custom g_prod, gdg_prod (rather than those derived in the forward SDE) is for efficiency reasons? How much is gained/lost? (Could we put the derivations from the forward SDE in as a fallback?)

lxuechen commented 3 years ago

(Copy from above since you commented on it down here, for clarity) I agree the return_U stuff is a bit of a hassle. The reason I arranged it like this was so:

  • the same BM can be used with multiple schemes (as in the diagnostics)
  • the default behaviour is still to return the increment.

I'm not sure what else would satisfy those two conditions.

EDIT: Thinking about it - we could remove return_U and return_A. Then in the diagnostic BM, set levy_area_approximation='space-time' and just wrap it as e.g. sdeint(..., bm=lambda ta, tb: bm(ta, tb)[0]) for the non-SRK solvers. Sound good?

I think this is a good idea. My only concern is that I'd prefer not to have another wrapper. I think we could make bm always return a tuple (even a singleton one!) for interval-based queries, and just let the solvers index the things they need. So rather than delegating the indexing to bm, we let the solver choose. WDYT?
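A sketch of the always-return-a-tuple alternative, where each solver indexes what it needs. `ToyBM` and `euler_step` are hypothetical illustrations; `ToyBM` does no caching, so repeated queries are not consistent the way a real Brownian motion's would be.

```python
import math
import random


class ToyBM:
    """Hypothetical scalar BM whose interval queries always return a tuple."""

    def __init__(self, levy_area_approximation="none", seed=0):
        self.levy_area_approximation = levy_area_approximation
        self._rng = random.Random(seed)

    def __call__(self, ta, tb):
        dt = tb - ta
        w = self._rng.gauss(0.0, math.sqrt(dt))
        if self.levy_area_approximation == "none":
            return (w,)  # singleton tuple
        h = self._rng.gauss(0.0, math.sqrt(dt / 12))
        return (w, h)


def euler_step(f, g, t, y, dt, bm):
    # Euler-Maruyama only needs the increment, so the solver indexes [0] itself.
    dw = bm(t, t + dt)[0]
    return y + f(t, y) * dt + g(t, y) * dw
```

The same `euler_step` then works with either BM configuration, without a wrapper.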

lxuechen commented 3 years ago

Thinking about strategy, something like this? (In order)

  • Merge this PR
  • Merge #33
  • Change the return_U/A stuff as in the previous comment.
  • Leave #31 until we've figured out what's going on with it, and in particular because it introduces prod which you say is a bit more difficult for the adjoint.
  • Remove logqp: I can see that maintaining this is an absolute hassle. I think eventually it would be nice to go back in (it's a very neat idea), but hopefully in a somewhat neater way. I'm thinking that it should be possible to factor it into some kind of wrapper SDE. (Doing it efficiently, without needlessly reevaluating f and g, is probably the difficult bit.)
  • Tuple-to-tensor rewrite
  • BTree levy area

I can adjust the return_U/A stuff, tear out logqp and try and make some free time to do tuple-to-tensor. Meanwhile you could look at BTree levy area and/or BPath optimisation? (Or whatever else you think is most important and won't clash.)

Removing logqp and doing the tuple-to-tensor rewrite should simplify the code dramatically, which will make everything going forward ten times easier.

I agree with the overall plan. Two concerns I have:

patrick-kidger commented 3 years ago

I think this is a good idea. My only concern is that I'd prefer not to have another wrapper. I think we could make bm always return a tuple (even a singleton one!) for interval-based queries, and just let the solvers index the things they need. So rather than delegating the indexing to bm, we let the solver choose. WDYT?

So the lambda wrapper is only used in our internal diagnostics, where I think we can accept bit more ugliness than in the public library. That said, I think the tuple solution is also a fine one; specifically I'd use a namedtuple for clearer lookup in the solvers. No strong feelings either way.
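A minimal sketch of the namedtuple variant. `BrownianIncrement` and its field names `W` and `H` are hypothetical, and the values are placeholders rather than real samples.

```python
from collections import namedtuple

# Hypothetical return type for interval queries.
BrownianIncrement = namedtuple("BrownianIncrement", ["W", "H"])

sample = BrownianIncrement(W=0.3, H=-0.05)  # placeholder values
dw = sample.W  # named lookup reads more clearly in solver code than sample[0]
```

Because a namedtuple is still a tuple, positional indexing keeps working for any code that already uses it.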

I agree with the overall plan. Two concerns I have:

* About #33: I'm not entirely convinced that this will always work, but feel free to prove me wrong. The fix seems to target the specific test rather than provide a fail-safe general solution. I still think stepping over the predefined region is not that bad an idea. Another way to restructure the checks is to put the time check in `check_contract`, where we check whether any part of `ts` lies outside `t0` and `t1` of `bm`, so that we don't always get a warning when we call `sdeint`.

* logqp: There's a simple way to do this by augmenting the state, but the code is likely to run slower without hacking the adjoint. The initial concern was that this would put more strain on the user if they wanted to use this model, but I guess we could live with that once we have other optimizations in place. So I think at some point we could remove this entirely, and I'm happy to do the work just in case something goes wrong.
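A sketch of the state-augmentation approach, assuming diagonal noise with nonzero diffusion and the usual KL integrand 0.5 * ||(f - h) / g||^2 between a posterior drift f and a prior drift h that share the diffusion g. `LogQPWrapper`, `ToySDE`, and the prior drift are hypothetical; note that f and g are re-evaluated separately, which is the inefficiency the strategy list flags as the difficult bit.

```python
import torch


class LogQPWrapper:
    """Hypothetical wrapper SDE: appends a logqp coordinate to the state (diagonal noise)."""

    def __init__(self, sde, prior_drift):
        self.sde = sde
        self.prior_drift = prior_drift

    def f(self, t, y_aug):
        y = y_aug[:, :-1]
        drift = self.sde.f(t, y)
        u = (drift - self.prior_drift(t, y)) / self.sde.g(t, y)  # requires g != 0
        logqp_drift = 0.5 * (u ** 2).sum(dim=1, keepdim=True)
        return torch.cat([drift, logqp_drift], dim=1)

    def g(self, t, y_aug):
        y = y_aug[:, :-1]
        # The logqp coordinate only accumulates via dt, so its diffusion is zero.
        return torch.cat([self.sde.g(t, y), torch.zeros_like(y_aug[:, -1:])], dim=1)


class ToySDE:
    """Toy Ornstein-Uhlenbeck-like SDE: f = -y, g = 1 (diagonal)."""

    def f(self, t, y):
        return -y

    def g(self, t, y):
        return torch.ones_like(y)


wrapped = LogQPWrapper(ToySDE(), prior_drift=lambda t, y: torch.zeros_like(y))
y_aug = torch.tensor([[1.0, 2.0, 0.0]])  # two state dims plus the logqp accumulator
```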

Regarding #33, I think it should always work. I'm not sure what you mean about checking times in check_contract as this was about fixing a floating point issue; the inputs were completely valid.

Regarding g_prod btw - right, I just hadn't clocked that this was a vjp. That makes total sense.
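A generic PyTorch sketch of why g_prod appears as a vector-Jacobian product: the adjoint needs a^T d(g_prod)/dy, which reverse-mode autograd gives in one backward pass. The toy diagonal diffusion sin(y) and the helper names are hypothetical, chosen so the result can be checked against the analytic a * cos(y) * v; this is not the code in adjoint_sde.py.

```python
import torch


def g(t, y):
    # Toy diagonal diffusion, elementwise sin(y); shape (batch, d).
    return torch.sin(y)


def g_prod(t, y, v):
    # Diagonal noise: the diffusion acts elementwise on the noise vector v.
    return g(t, y) * v


def g_prod_vjp(t, y, v, a):
    """Compute a^T d(g_prod)/dy with a single reverse-mode pass."""
    with torch.enable_grad():
        y = y.detach().requires_grad_(True)
        out = g_prod(t, y, v)
        (vjp,) = torch.autograd.grad(out, y, grad_outputs=a)
    return vjp


torch.manual_seed(0)
y, v, a = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)
vjp = g_prod_vjp(0.0, y, v, a)
```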

lxuechen commented 3 years ago

On another note, I think BrownianPath needs a partial rewrite due to the conditioning problem of levy area A. I will take those files down from this PR and send in another PR at a later time.

patrick-kidger commented 3 years ago

Makes sense. Let me know once you want me to review this PR again.

lxuechen commented 3 years ago

@patrick-kidger Made some minor fixes. Placed TODOs where I think additional work is required to get the Levy area working.

lxuechen commented 3 years ago

That said, I think the tuple solution is also a fine one; specifically I'd use a namedtuple for clearer lookup in the solvers. No strong feelings either way.

Namedtuple would be much slower than tuple though.

I'm not sure what you mean about checking times in check_contract as this was about fixing a floating point issue; the inputs were completely valid.

I'm saying that I don't see a problem in letting the solver query times that are slightly outside the prespecified range. If you're worried about the bm always spitting out warnings for that in __call__, we could remove them and only do a time check in check_contract. In that case, it wouldn't produce warnings if the user does everything correctly.
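A sketch of a one-off range check at validation time, assuming ts is sorted and the Brownian motion's interval is available as t0 and t1. `check_ts_within_bm` is a hypothetical helper, and raising an error (rather than warning) is just one possible choice.

```python
def check_ts_within_bm(ts, bm_t0, bm_t1):
    """Hypothetical check_contract-style validation, run once before solving.

    Assumes ts is sorted, so only its endpoints need checking.
    """
    if ts[0] < bm_t0 or ts[-1] > bm_t1:
        raise ValueError(
            f"ts spans [{ts[0]}, {ts[-1]}], outside the Brownian motion's [{bm_t0}, {bm_t1}]."
        )


check_ts_within_bm([0.0, 0.5, 1.0], bm_t0=0.0, bm_t1=1.0)  # passes silently
```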

patrick-kidger commented 3 years ago

I'll start going over it.

Fair enough that namedtuple may be too slow. (Although idk how much it matters in the grand scheme of things.) I'd suggest a custom class with __slots__ then.
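A sketch of the __slots__ option, plus a rough way to measure construction cost against tuple and namedtuple. `BrownianIncrementSlots` and `construction_times` are hypothetical names; timings depend on machine and Python version, so no numbers are claimed here.

```python
import timeit
from collections import namedtuple


class BrownianIncrementSlots:
    """Hypothetical result container: named fields, no per-instance __dict__."""

    __slots__ = ("W", "H")

    def __init__(self, W, H=None):
        self.W = W
        self.H = H


BrownianIncrementNT = namedtuple("BrownianIncrementNT", ["W", "H"])


def construction_times(number=100_000):
    """Seconds to build `number` objects of each kind (rough, machine-dependent)."""
    w, h = 1.0, 2.0
    return {
        "tuple": timeit.timeit(lambda: (w, h), number=number),
        "namedtuple": timeit.timeit(lambda: BrownianIncrementNT(w, h), number=number),
        "slots": timeit.timeit(lambda: BrownianIncrementSlots(w, h), number=number),
    }


sample = BrownianIncrementSlots(0.3, -0.05)
```

Calling `construction_times()` and comparing the three entries is a quick way to check whether the difference matters in practice.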

So the reason that the BM spits out warnings in __call__ is that querying outside the specified range simply isn't defined for BrownianInterval; the data structure just isn't set up to handle that. (I suspect you could make a similar argument for BrownianTree? I've not looked too closely at exactly what you've done for that.) I don't see a way around that. Regardless, I'm more confident in #33 than you seem to be! I think it does the right thing in all scenarios.

patrick-kidger commented 3 years ago

So I'm trying to figure out why Strat for all noise types isn't supported. That's what I had in my head + what we've discussed but I can no longer find where in adjoint_sde.py things go wrong? Maybe I'm just tired.

(Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)

lxuechen commented 3 years ago

Let's do general noise for now. And it should be sufficient if you want to run certain experiments. I can do the small fixes for other noise types in another PR, but it would require quite some work, mostly regarding testing.

At this moment, I'm more interested in getting the basic ingredients working for the methods you guys proposed, as opposed to adding features just for completeness.

patrick-kidger commented 3 years ago

I'm not asking for them, don't worry. I just mean: I'm confused; what part of the code doesn't already support that? I had it in my head that we couldn't, and now I can't see why I thought that.

lxuechen commented 3 years ago

This PR is ready in my opinion. Remaining things on my list include

It might be nice if you could check out BrownianInterval again for these things

I'm now convinced that the min(..., ts[-1]) solution could work, provided there's no arithmetic inside min, which I don't think there would be. There might, however, be float -> torch.tensor conversions (or the other way around), which could lose numerical precision. But if ts is always a tensor, which we now enforce, I don't think there'd be an issue.
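A tiny illustration of the floating-point overshoot and the clamp, using plain Python floats; `next_time` is a hypothetical helper. Since min only selects one of its inputs, the result is exactly the end time whenever the step would overshoot.

```python
def next_time(t, dt, t_end):
    # Clamp so floating-point accumulation can never step past the final time.
    return min(t + dt, t_end)


t, dt, t_end = 0.1, 0.2, 0.3
assert t + dt > t_end  # 0.1 + 0.2 == 0.30000000000000004 in IEEE-754 doubles
assert next_time(t, dt, t_end) == t_end
```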

lxuechen commented 3 years ago

I think the set of methods in adjoint_sde.py already supports additive and scalar noise; diagonal might need a little more thinking, but should be straightforward once I spend some time on it. The bulky thing that I want to leave to another PR is fixing the tests and test problems. I would like to eventually have the adjoint for Strat SDEs numerically checked for the different noise types, and I'm a little worried that it might take quite some time. Atm, I don't want to block other people's work, if that makes sense.

lxuechen commented 3 years ago

(Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)

Without the function for computing the dg g A term, you wouldn't be able to use adjoints with the new log-ODE scheme. So this will also be at the top of my priority list.

patrick-kidger commented 3 years ago

Alright, looks good to me. I'll hit the merge button.

Everything you've said makes sense, in particular the to-do list. I'll get my part of that done. (One extra thing for whichever of us gets to it first: BrownianTree also needs Levy area support.)