google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Improved heuristic for BrownianInterval's dependency tree. #40

Closed patrick-kidger closed 3 years ago

patrick-kidger commented 3 years ago

Fixes #36. What was meant to be a small speedup ended up being a Russian roulette between "small speedup" and "huge slowdown".

lxuechen commented 3 years ago

Could you explain a bit more what's changed?

Also, this doesn't seem like a complete fix to me. It is still possible to get very small step sizes, and the issue will likely bug the updated BTree and BPath in the future.

patrick-kidger commented 3 years ago

Sure. So BInterval relies on a dependency structure between queried intervals, as you know. If that's the only structure that is imposed, then a problem arises: the backward pass ends up taking O(n^2) work in the number of steps n, because you have to recompute your way forwards from the start to get to the queried interval, then step back a bit and repeat the whole procedure. (The caching offsets this somewhat but doesn't change the asymptotics.)

The resolution to this is to prespecify a tree-like dependency structure. This increases the cost of the forward pass from O(n) to O(n log n), but importantly decreases the cost of the backward pass to O(n log n) as well. The problem now is that we have to pick a resolution to create this dependency structure down to. The optimal resolution is a function of the average step size in the solver and the size of the cache.
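
A rough cost model for the backward pass in the two cases may make the asymptotics concrete. This is only a sketch under simplifying assumptions that are not from the code: the cache is ignored, each recomputed interval costs one unit, and the tree is resolved down to roughly the step size, so it has about n leaves and depth about log2(n).

```latex
\underbrace{\sum_{k=1}^{n} k}_{\text{recompute from the start for each of the } n \text{ queries}}
  = \frac{n(n+1)}{2} = O(n^2),
\qquad
\underbrace{n \cdot \lceil \log_2 n \rceil}_{n \text{ queries, each walking a tree of depth } \lceil \log_2 n \rceil}
  = O(n \log n).
```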

The bit that's been changed is the heuristic for figuring out the average step size in the solver. Previously, it was very crude: if any given step was less than 0.1 times the smallest step recorded so far, then that was taken to be the new average step size, and the dependency tree was recreated with respect to that step size. But because the terminal step in the new solver can now be arbitrarily small, this final step could trigger such a (completely unnecessary) dependency tree recreation.

The fix is to actually calculate the average length of the intervals it's been queried with, and only trigger a dependency tree creation if that gets small enough. Just one tiny step isn't enough to trigger it.
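
To illustrate the difference between the two heuristics, here is a minimal sketch. It is not the torchsde implementation: the class names, the `observe` method, and the `shrink_factor` threshold for "small enough" are all hypothetical, chosen only to show why a single tiny terminal step trips the old rule but not the averaged one.

```python
class MinStepHeuristic:
    """Old rule as described above: trigger a rebuild whenever a step is
    less than 0.1 times the smallest step seen so far. (Hypothetical class.)"""

    def __init__(self):
        self.smallest = None

    def observe(self, step):
        if self.smallest is None:
            self.smallest = step
            return False
        trigger = step < 0.1 * self.smallest
        self.smallest = min(self.smallest, step)
        return trigger


class AverageStepHeuristic:
    """Averaged rule: track the mean queried interval length and trigger a
    rebuild only when that mean falls well below the step size the tree was
    last built for. `shrink_factor` is an assumed threshold, not from the PR."""

    def __init__(self, shrink_factor=0.1):
        self.total = 0.0
        self.count = 0
        self.tree_step = None  # step size the tree was last built for
        self.shrink_factor = shrink_factor

    def observe(self, step):
        self.total += step
        self.count += 1
        average = self.total / self.count
        if self.tree_step is None:
            self.tree_step = average
            return False
        if average < self.shrink_factor * self.tree_step:
            self.tree_step = average
            return True
        return False


# Nine ordinary steps followed by one arbitrarily small terminal step.
steps = [0.1] * 9 + [1e-6]
old, new = MinStepHeuristic(), AverageStepHeuristic()
old_triggers = [old.observe(s) for s in steps]
new_triggers = [new.observe(s) for s in steps]
```

With these inputs the old rule fires on the final step, while the averaged rule never fires, since one tiny step barely moves the mean.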

This isn't an issue that should affect BTree or BPath - it's just an undesired interaction between the updated solver and an overly naive heuristic in BInterval.

lxuechen commented 3 years ago

Ok, thanks for explaining the average step size computation part. LGTM.

I think at some point, I want to revisit the weird behavior that dt is very small at the end. This actually does affect the new BPath I wrote, since it currently relies on Cholesky, and it's proving to not be very numerically stable. Though, alternatively, I could also use the H generation formulae in James' thesis and get the same thing.