google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Tweaked to not hang on adaptive solvers. #33

Closed: patrick-kidger closed this 3 years ago

patrick-kidger commented 3 years ago

From the discussion in #28. Figured out what was going on and fixed it.

I can go into details if you're curious.

(And the tests in test_sdeint pass - suitably modified by forcing the Brownian motions to be None to not throw an error on SRK; but I've not included that in the PR as I think you're working on those files.)

The extra step_ function introduced is only intended as a placeholder until we've got #31 merged in, and then we can tweak all the solvers' interfaces to that of step_.

lxuechen commented 3 years ago

Thanks for the fix! Could you describe a little more of what you think the problem is?

Also, I think we should merge #21 first as it was completed last night, and also merge #31 before this.

patrick-kidger commented 3 years ago

Sure thing. Previously we had a setup like:

while a < b:
    step = ...
    step = min(step, b - a)
    ...
    a += step

Eventually a gets close enough to b that the min kicks in and step = b - a. You'd then expect a += step to give a <- a + (b - a) = b, so that the while loop completes. I think what actually happens is that a + (b - a) < b due to rounding errors, so the loop never terminates. An explicit example in which a + (b - a) differs from b: a = 1e30, b = 1. Then a + (b - a) evaluates to 0, whereas b is 1. [Obviously this example can't be exactly what's going on because we additionally have a < b in our case, but I expect it's something like this.]
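
For illustration only, here is a small float64 check of the rounding claim. The helper name round_trip is made up for this sketch, and the second pair of values is hand-picked to hit round-half-to-even ties; it is not taken from the actual failing run.

```python
def round_trip(a, b):
    """Return a + (b - a) as evaluated in float64 arithmetic."""
    return a + (b - a)


# The example from the comment: b - a rounds to -1e30, so the sum is 0.0, not 1.0.
big_a, small_b = 1e30, 1.0
print(round_trip(big_a, small_b) == small_b)   # the identity fails here

# A hand-picked pair that also satisfies a < b.
a = 2.0 ** -53
b = 1.0 + 2.0 ** -52        # the next float64 after 1.0
# The exact value of b - a is 1 + 2**-53, a tie that rounds to 1.0;
# adding a back gives 1 + 2**-53 again, which rounds to 1.0 once more.
print(a < b, round_trip(a, b) < b)
```

Note that in pure float64 a second pass of the loop would land exactly on b for this pair, so the sketch only shows that a + (b - a) == b is not guaranteed; it does not reproduce the hang itself.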

The trick is just not to work with delta-ts; instead explicitly work with the points that you're stepping between.
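
A minimal sketch of that idea, assuming a fixed step size and a hypothetical function name (integrate_by_endpoints); the real solvers choose the step adaptively, and this is not the code from the PR. The point is that the endpoint is clamped and assigned directly, so the final value is b itself rather than the result of an addition.

```python
def integrate_by_endpoints(t_start, t_end, dt):
    """Walk from t_start to t_end, tracking interval endpoints instead of deltas."""
    t0 = t_start
    visited = [t0]
    while t0 < t_end:
        # Clamp the endpoint itself, not the step size.
        t1 = min(t0 + dt, t_end)
        # ... a solver would advance its state from t0 to t1 here ...
        t0 = t1  # assigned, never accumulated, so the last value is exactly t_end
        visited.append(t0)
    return visited


print(integrate_by_endpoints(0.0, 1.0, 0.3))
```

Because the last assignment is t0 = t_end, the loop condition t0 < t_end becomes false without relying on a + (b - a) rounding back to b.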

And ah, sorry I didn't see that #21 is done - I think I might not get notifications from mentions in edits? (Idk.) I'll start reviewing it.

lxuechen commented 3 years ago

Since #21 is merged, I think it'd be nice for us to have this tweak done more completely. In particular, I'd recommend changing the step function for the solvers to be def step(self, t0, y0, t1). Alternatively, we could stick with the wrapper step function, but I think we should have a wrapper for both step and step_logqp just to mirror the functionality. This might help me in the future when I fix the logqp component.
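
To make the proposed signature concrete, here is a rough scalar Euler-Maruyama sketch. Everything except the step(self, t0, y0, t1) signature is an assumption for illustration: the class name, the integrate driver, the return value, and the use of Python's random module are hypothetical and do not reflect torchsde's actual solvers.

```python
import math
import random


class EulerMaruyamaSketch:
    """Hypothetical scalar solver; not one of torchsde's classes."""

    def __init__(self, drift, diffusion, seed=0):
        self.drift = drift
        self.diffusion = diffusion
        self.rng = random.Random(seed)

    def step(self, t0, y0, t1):
        # The step size is derived from the two endpoints on demand.
        dt = t1 - t0
        dw = self.rng.gauss(0.0, math.sqrt(dt))  # Brownian increment over [t0, t1]
        return y0 + self.drift(t0, y0) * dt + self.diffusion(t0, y0) * dw


def integrate(solver, y0, t_start, t_end, dt):
    """Hypothetical driver loop that only ever passes endpoints to the solver."""
    t, y = t_start, y0
    while t < t_end:
        t1 = min(t + dt, t_end)
        y = solver.step(t, y, t1)
        t = t1
    return t, y


# With zero diffusion and unit drift, y should track t.
solver = EulerMaruyamaSketch(drift=lambda t, y: 1.0, diffusion=lambda t, y: 0.0)
print(integrate(solver, 0.0, 0.0, 1.0, 0.3))
```

With this shape the driver owns the time grid and the solver never has to reconstruct t1 from t0 and a delta.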

patrick-kidger commented 3 years ago

Changing all the solvers' interfaces is what I'm thinking too. I'm hoping to get this PR done tomorrow.

patrick-kidger commented 3 years ago

@lxuechen Sorry this took so long. All working now I think. You'll see a couple other changes in there too; a couple other unrelated things needed tweaking for the tests to pass.

lxuechen commented 3 years ago

No worries! And thanks for coming up with the fix so quickly in the first place. Going through everything now, I see that most of the changes are in a separate PR, so I'll stick with that one for now. Are the two basically doing the same thing?

patrick-kidger commented 3 years ago

Ah, the other PR is just derived from this one is all. (I.e. this plus some other stuff.) I kept them separate as they're addressing different things but in retrospect there was no real need for that.