Closed lxuechen closed 3 years ago
Excellent, thanks. If this is still in progress then ping me when you want me to go over it.
A subtle issue I found was that ft.reduce(operator.eq, l) doesn't actually check that all the entries in the list are equal. A typical example is ft.reduce(operator.eq, [0, 0, 0]). The reason is that the first eq check returns True, and True is not eq to 0.
I fixed this for BInterval in this PR.
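A minimal demonstration of the pitfall (the all_equal helper below is just an illustrative replacement, not the actual fix in the PR):

```python
import functools as ft
import operator

# reduce folds pairwise, so intermediate booleans get compared
# against the remaining elements:
#   eq(eq(0, 0), 0)  ->  eq(True, 0)  ->  False
result = ft.reduce(operator.eq, [0, 0, 0])
print(result)  # False, even though all entries are equal

# One correct check: compare every entry against the first.
def all_equal(l):
    return all(x == l[0] for x in l)

print(all_equal([0, 0, 0]))  # True
print(all_equal([0, 1, 0]))  # False
```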
I'll start looking at the code now. Regarding your further bullet points:
Sorry, I didn't realize that no new notification is sent when a mention is added in an edit. Will start a new message next time.
No strong feelings on whether the interface for scalar noise is (..., d) or (..., d, 1), but I think we should pick just one and not try and support both. I'd note that the latter is what the code is currently set up for + what's currently in the README.
This seems good. I need to change the step function a bit then, as I noticed that previously it was assumed that g outputs (..., d) and bm outputs (..., 1).
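For concreteness, here is how a (..., d) diffusion output combines with a (..., 1) scalar-noise increment (sketched with NumPy as a stand-in; PyTorch follows the same broadcasting rules):

```python
import numpy as np

batch, d = 4, 3
g = np.random.randn(batch, d)    # diffusion output, shape (..., d)
dW = np.random.randn(batch, 1)   # scalar-noise increment, shape (..., 1)

# Elementwise multiply broadcasts the single noise channel across all
# d state dimensions, giving a product of shape (..., d).
g_prod = g * dW
assert g_prod.shape == (batch, d)
```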
Which type checks are slowing things down?
The control-flow in check_contract seems to be costly for very small problems. At some point, I think we could have something that recognizes an environment variable that disables certain checks. What do you think about this? Certainly not at the top of the priority list though.
Phew, what a lot of comments. Have fun with those...
Which type checks are slowing things down?
The control-flow in check_contract seems to be costly for very small problems. At some point, I think we could have something that recognizes an environment variable that disables certain checks. What do you think about this? Certainly not at the top of the priority list though.
I agree, not a priority right now. I'd say probably not an environment variable; if nothing else, no other part of PyTorch does this. FWIW if it's a small problem that gets solved blazingly fast anyway, then the only time we'd be concerned about this is if the user is solving lots of small problems, sequentially, without batching. Which is sufficiently niche that I'm not too concerned.
If we really want this then I'd suggest an additional unsafe argument that disables the checks.
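A hypothetical sketch of what such an unsafe argument could look like (solve, _validate, and _integrate are illustrative names, not the library's API):

```python
def _validate(ts):
    # Placeholder for the expensive shape/type checks.
    if len(ts) < 2:
        raise ValueError("ts must contain at least two time points")

def _integrate(y0, ts):
    # Placeholder for the actual solver loop.
    return y0

def solve(y0, ts, unsafe=False):
    # unsafe=True skips validation entirely, trading safety for
    # lower Python overhead on tiny problems.
    if not unsafe:
        _validate(ts)
    return _integrate(y0, ts)

assert solve(1.0, [0.0, 1.0]) == 1.0
assert solve(2.5, [0.0], unsafe=True) == 2.5  # checks skipped
```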
More broadly, I think we should be able to remove return_U and return_A in __call__.
Pros: Makes the code cleaner and reduces Python overhead.
Cons: Makes the bm less flexible, as its __call__ would be fixed when the object is born.
(Copy from above since you comment on it down here, for clarity)
I agree the return_U stuff is a bit of a hassle. The reason I arranged it like this was so:
- the same BM can be used with multiple schemes (as in the diagnostics)
- the default behaviour is still to return the increment.
I'm not sure what else would satisfy those two conditions.
EDIT: Thinking about it - we could remove return_U and return_A. Then in the diagnostic BM, set levy_area_approximation='space-time' and just wrap it as e.g. sdeint(..., bm=lambda ta, tb: bm(ta, tb)[0]) for the non-SRK solvers. Sound good?
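The wrapping idea in miniature (toy_bm below is a stand-in for the real Brownian motion object, here returning an (increment, Levy area) pair):

```python
# Toy bm: interval queries return (increment W, space-time Levy area U).
def toy_bm(ta, tb):
    W = tb - ta          # stand-in for the Brownian increment
    U = 0.0              # stand-in for the Levy area term
    return W, U

# Non-SRK solvers only consume the increment, so a lambda can
# discard the extra output without touching the bm object itself.
increment_only = lambda ta, tb: toy_bm(ta, tb)[0]

assert increment_only(0.0, 0.5) == 0.5
```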
Thinking about strategy, something like this? (In order)
- Merge this PR
- Merge #33
- Change the return_U/A stuff as in the previous comment.
- Leave #31 until we've figured out what's going on with it, and in particular because it introduces prod which you say is a bit more difficult for the adjoint.
- Remove logqp: I can see that maintaining this is an absolute hassle. I think eventually it would be nice to go back in (it's a very neat idea), but hopefully in a somewhat neater way. I'm thinking that it should be possible to factor it into some kind of wrapper SDE. (Doing it efficiently, without needlessly reevaluating f and g, is probably the difficult bit.)
- Tuple-to-tensor rewrite
- BTree levy area
I can adjust the return_U/A stuff, tear out logqp and try and make some free time to do tuple-to-tensor. Meanwhile you could look at BTree levy area and/or BPath optimisation? (Or whatever else you think is most important and won't clash.)
logqp and tuple-to-tensor should simplify the code dramatically, which will make going forward ten times easier.
Leaving another comment to keep points organised: I'm guessing the reason for implementing the adjoint SDE with custom g_prod, gdg_prod (rather than those derived in the forward SDE) is for efficiency reasons? How much is gained/lost? (Could we put the derivations from the forward SDE in as a fallback?)
(Copy from above since you comment on it down here, for clarity) I agree the return_U stuff is a bit of a hassle. The reason I arranged it like this was so:
- the same BM can be used with multiple schemes (as in the diagnostics)
- the default behaviour is still to return the increment.
I'm not sure what else would satisfy those two conditions.
EDIT: Thinking about it - we could remove return_U and return_A. Then in the diagnostic BM, set levy_area_approximation='space-time' and just wrap it as e.g. sdeint(..., bm=lambda ta, tb: bm(ta, tb)[0]) for the non-SRK solvers. Sound good?
I think this is a good idea. The only concern is that I'd prefer not to have another wrapper. I think we could make bm always return a tuple (even a singleton one!) when we do interval-based queries, and just let the solvers index the things needed. So rather than delegating the indexing to bm, we let the solver choose. WDYT?
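A sketch of the always-return-a-tuple convention (bm_interval and its levy_area flag are illustrative names, not the actual interface):

```python
def bm_interval(ta, tb, levy_area=False):
    W = tb - ta                        # toy increment
    if levy_area:
        U = 0.5 * (tb - ta) ** 2       # toy space-time Levy area
        return (W, U)
    return (W,)                        # singleton tuple: no return_U/A flags

# An Euler-type solver unpacks only the increment...
(W,) = bm_interval(0.0, 1.0)
# ...while an SRK-type solver asks for, and unpacks, both.
W, U = bm_interval(0.0, 1.0, levy_area=True)
assert W == 1.0 and U == 0.5
```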
Thinking about strategy, something like this? (In order)
- Merge this PR
- Merge #33
- Change the return_U/A stuff as in the previous comment.
- Leave #31 until we've figured out what's going on with it, and in particular because it introduces prod which you say is a bit more difficult for the adjoint.
- Remove logqp: I can see that maintaining this is an absolute hassle. I think eventually it would be nice to go back in (it's a very neat idea), but hopefully in a somewhat neater way. I'm thinking that it should be possible to factor it into some kind of wrapper SDE. (Doing it efficiently, without needlessly reevaluating f and g, is probably the difficult bit.)
- Tuple-to-tensor rewrite
- BTree levy area
I can adjust the return_U/A stuff, tear out logqp and try and make some free time to do tuple-to-tensor. Meanwhile you could look at BTree levy area and/or BPath optimisation? (Or whatever else you think is most important and won't clash.)
logqp and tuple-to-tensor should simplify the code dramatically, which will make going forward ten times easier.
I agree with the overall plan. Two concerns I have:
About #33: I'm not entirely convinced that this will always work, but feel free to prove me wrong if you can. The fix seems to target the specific test as opposed to providing a fail-safe general solution. I still think stepping over the predefined region is not that bad an idea. Another way to restructure the checks is to put the time check in check_contract, where it is checked whether any part of ts lies outside of t0 and t1 of bm, so that we don't always get a warning when we call sdeint.
logqp: There's a simple way to do this by augmenting the state, but the code is likely to run slower without hacking the adjoint. The initial concern was that this would put more strain on the user if they wanted to use this model, but I guess we could live with that once we have other optimizations in place. So I think at some point we could remove this entirely, and I'm happy to do the work just in case something goes wrong.
I think this is a good idea. The only concern is that I'd prefer not to have another wrapper. I think we could make bm always return a tuple (even a singleton one!) when we do interval-based queries, and just let the solvers index the things needed. So rather than delegating the indexing to bm, we let the solver choose. WDYT?
So the lambda wrapper is only used in our internal diagnostics, where I think we can accept a bit more ugliness than in the public library. That said, I think the tuple solution is also a fine one; specifically I'd use a namedtuple for clearer lookup in the solvers. No strong feelings either way.
I agree with the overall plan. Two concerns I have:
* About #33: I'm not entirely convinced that this will always work, but feel free to prove me wrong if you can. The fix seems to target the specific test as opposed to providing a fail-safe general solution. I still think stepping over the predefined region is not that bad an idea. Another way to restructure the checks is to put the time check in `check_contract`, where it is checked whether any part of `ts` lies outside of `t0` and `t1` of `bm`, so that we don't always get a warning when we call `sdeint`.
* logqp: There's a simple way to do this by augmenting the state, but the code is likely to run slower without hacking the adjoint. The initial concern was that this would put more strain on the user if they wanted to use this model, but I guess we could live with that once we have other optimizations in place. So I think at some point we could remove this entirely, and I'm happy to do the work just in case something goes wrong.
Regarding #33, I think it should always work. I'm not sure what you mean about checking times in check_contract as this was about fixing a floating point issue; the inputs were completely valid.
Regarding g_prod btw - right, I just hadn't clocked that this was a vjp. That makes total sense.
On another note, I think BrownianPath needs a partial rewrite due to the conditioning problem of levy area A. I will take those files down from this PR and send in another PR at a later time.
Makes sense. Let me know once you want me to review this PR again.
@patrick-kidger Made some minor fixes. Placed todos at places that I think would require some additional work to get the Levy area working.
That said, I think the tuple solution is also a fine one; specifically I'd use a namedtuple for clearer lookup in the solvers. No strong feelings either way.
Namedtuple would be much slower than tuple though.
I'm not sure what you mean about checking times in check_contract as this was about fixing a floating point issue; the inputs were completely valid.
I'm saying that I don't see a problem in letting the solver query times that are slightly outside the prespecified range. If you're worried about the bm always spitting out warnings for that in __call__, we could remove that and only do a time check in check_contract. And in that case, it wouldn't produce warnings if the user does everything correctly.
I'll start going over it.
Fair enough that namedtuple may be too slow. (Although idk how much it matters in the grand scheme of things.) I'd suggest a custom class with __slots__ then.
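Roughly what a __slots__ return type could look like (BrownianReturn is an illustrative name; attribute access stays named, without namedtuple's tuple machinery):

```python
class BrownianReturn:
    # __slots__ avoids a per-instance __dict__, cutting memory and
    # attribute-lookup overhead versus a plain class.
    __slots__ = ('W', 'U', 'A')

    def __init__(self, W, U=None, A=None):
        self.W = W  # Brownian increment
        self.U = U  # space-time Levy area, if requested
        self.A = A  # full Levy area, if requested

out = BrownianReturn(W=0.3, U=0.1)
assert out.W == 0.3
assert out.A is None
```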
So the reason that the BM spits out warnings in __call__ is that querying outside the specified range simply isn't defined for BrownianInterval; the data structure just isn't set up to handle that. (I suspect you could make a similar argument for BrownianTree? I've not looked too closely at exactly what you've done for that.) I don't see a way around that. Regardless, I'm more confident in #33 than you seem to be! I think it does the right thing in all scenarios.
So I'm trying to figure out why Strat for all noise types isn't supported. That's what I had in my head + what we've discussed but I can no longer find where in adjoint_sde.py things go wrong? Maybe I'm just tired.
(Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)
So I'm trying to figure out why Strat for all noise types isn't supported. That's what I had in my head + what we've discussed but I can no longer find where in adjoint_sde.py things go wrong? Maybe I'm just tired.
(Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)
Let's do general noise for now. And it should be sufficient if you want to run certain experiments. I can do the small fixes for other noise types in another PR, but it would require quite some work, mostly regarding testing.
At this moment, I'm more interested in getting the basic ingredients working for the methods you guys proposed, as opposed to adding features just for completeness.
So I'm trying to figure out why Strat for all noise types isn't supported. That's what I had in my head + what we've discussed but I can no longer find where in adjoint_sde.py things go wrong? Maybe I'm just tired. (Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)
Let's do general noise for now. And it should be sufficient if you want to run certain experiments. I can do the small fixes for other noise types in another PR, but it would require quite some work, mostly regarding testing.
At this moment, I'm more interested in getting the basic ingredients working for the methods you guys proposed, as opposed to adding features just for completeness.
I'm not asking for them, dw. I just mean - I'm confused; what part of the code doesn't already support that? I had it in my head that we couldn't, and now I can't see why I thought that.
This PR is ready in my opinion. Remaining things on my list include
It might be nice if you could check out BrownianInterval again for these things
utils.check_tensor_info
I'm now convinced that the min(..., ts[-1]) sol'n could work, if there's no arithmetic operation in min, which I don't think there would be. But there might be float -> torch.tensor conversions (or the other way around), which might lead to loss of numerical precision. But if ts is always a tensor, which we enforce now, I don't think there'd be an issue.
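The point about min in miniature: clamping only selects one of the existing values, so no rounding occurs (plain floats here; the same reasoning applies when both operands are already tensors of the same dtype):

```python
ts = [0.0, 0.5, 1.0]

# A query time nudged just past the endpoint by floating point error.
t = 1.0 + 1e-7

# min performs a comparison and returns one of its inputs unchanged;
# there is no arithmetic, hence no loss of precision.
t_clamped = min(t, ts[-1])
assert t_clamped == 1.0
```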
So I'm trying to figure out why Strat for all noise types isn't supported. That's what I had in my head + what we've discussed but I can no longer find where in adjoint_sde.py things go wrong? Maybe I'm just tired. (Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)
Let's do general noise for now. And it should be sufficient if you want to run certain experiments. I can do the small fixes for other noise types in another PR, but it would require quite some work, mostly regarding testing. At this moment, I'm more interested in getting the basic ingredients working for the methods you guys proposed, as opposed to adding features just for completeness.
I'm not asking for them, dw. I just mean - I'm confused; what part of the code doesn't already support that? I had it in my head that we couldn't, and now I can't see why I thought that.
I think the set of methods in adjoint_sde.py already supports additive and scalar; diagonal might need a little more thinking, but should be straightforward after I spend some time. The bulky thing that I want to leave to another PR is fixes to the tests and test problems. I would like to eventually have adjoint for Strat SDEs with different noise types numerically checked, and I'm a little worried that it might take quite some time. Atm, I don't want to block other people's work, if that makes sense.
(Except for gdg_prod_default not yet being implemented, which I think only excludes using the derivative-using Milstein method for now.)
Without the function for computing the dg g A term, you wouldn't be able to use adjoints with the new log-ODE scheme. So this will also be on the top of my priority list.
Alright, looks good to me. I'll hit the merge button.
Everything you've said makes sense, in particular the to-do list. I'll get my part of that done. (One extra thing for whichever of us gets to it first: BrownianTree also needs Levy area support.)
Opening this up is mainly to let you know this is in progress.
The remaining stuff:
Add gdg_jvp support. (Do this in another PR)
Update: It's all done. Would appreciate comments on getting the code in better shape, as although I tried to think carefully through most of the code, some parts were done hastily. Now all tests pass. @patrick-kidger
This PR is really long, I'd prefer not to block/be blocked by others' work. So I'm leaving gdg_jvp for adjoint in another PR.
A couple of random things that playing around with the new code has made me think about:
- sdeint and sdeint_adjoint for small problems. I noticed this after running on some small examples with the new code.
Update: Additional caveats: