The difficulty is that backprop via RecursiveCheckpointAdjoint involves recomputing intermediate states from checkpoints. I'm not confident that this recomputation will be bitwise identical to the original forward pass -- e.g. perhaps due to nondeterminism in convolutions.
Meanwhile, UBP casts float->int->key and uses that to sample its random noise. This means that floating point fluctuations will produce entirely different Brownian motion samples.
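To make that concrete, here is a minimal sketch of why keying the noise off the evaluation time is brittle under recomputation. This is not Diffrax's actual code and `noise_from_time` is a hypothetical helper; it just illustrates the float->int->key pattern described above.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def noise_from_time(key, t):
    # Bitcast the float time to an integer, fold it into the base key,
    # and sample the "Brownian increment" from the resulting key.
    t_bits = jax.lax.bitcast_convert_type(jnp.asarray(t, jnp.float32), jnp.int32)
    return jr.normal(jr.fold_in(key, t_bits))

key = jr.PRNGKey(0)
t = jnp.float32(0.3)
t_recomputed = t + jnp.float32(1e-7)  # a tiny discrepancy in a recomputed time
# The two samples are unrelated, even though the times differ only in the last few bits.
print(noise_from_time(key, t), noise_from_time(key, t_recomputed))
```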
I guess it seems to me that these are solvable engineering problems (I hope), rather than theoretical limitations. For example, we could store every sample made on the forward pass and then reuse those samples (or use a bridge) on the backward pass -- memory isn't a concern for us, since neural networks aren't in play. That seems possible, maybe not with a UBP, but with something simpler than a VBT. @frankschae can help articulate this better, and why we are interested in this direction. The motivation is basically that VBT is potentially overkill for us and might introduce slowdowns we don't need: if there is a distinction between UBP, VBT, and a third object that samples differentiably, has some misc. machinery for the weak solvers we are working on, and supports bridges, then that third object might be the most interesting.
Yup, this all makes sense.
If our controls consumed a step counter or something then one possible approach would just be to jr.fold_in
the step index into a key, and that should operate reliably. Or if controls were stateful then we could jr.split
a key each step. The current approach does what it does just because right now, a control consumes only the times at which it is evaluated.
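For illustration, a minimal sketch of those two options, outside of any Diffrax machinery (the function names here are hypothetical):

```python
import jax.random as jr

key = jr.PRNGKey(0)

# Option 1: the control consumes a step counter, so we fold the step index into a
# base key. Recomputing step 3 always reproduces the same sample.
def increment_from_step(key, step):
    return jr.normal(jr.fold_in(key, step))

print(increment_from_step(key, 3), increment_from_step(key, 3))  # identical

# Option 2: the control is stateful, so we split a fresh key each step and thread
# the remaining key through as the control's state.
def increment_stateful(control_state):
    key, subkey = jr.split(control_state)
    return jr.normal(subkey), key  # (sample, new control state)
```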
I think this dovetails with the VBT discussion in #489 -- maybe we should think about modifying the way controls are handled, and if we pick our abstractions right we can tackle this issue whilst we're at it.
I think in general, having some concept of state makes sense for general controls. Even if all existing controls can be done without it, we are working on more flexible controls that might take advantage of that (and it would allow more implementations of custom controls in an easier way). That being said, exactly how to do it is open. I think once we have a weak solver PR open to motivate this, it can provide more concrete examples.
As we've explored more, I think stateful controls can make a lot of sense and would be useful. For the limited usage we currently have, just leeching off the solver state and adding an argument to UBP is sufficient (see: https://github.com/lockwo/diffrax_extensions/blob/Owen/new-weak-internal/diffrax/_brownian/path.py#L114). A more robust and mainline implementation would be beneficial for a couple of reasons.
However, as I went about a draft implementation, one possible concern is that it would be a breaking change for a (maybe small) class of users (or it would require some back-checking). In the most abstract form, my approach was to add a state argument to AbstractPath's evaluate plus an init method, and then call that init inside integrate alongside all the others. For anyone using UBP or VBT this is totally invisible, since in the default workflow the user never actually calls evaluate on these classes (and even with the new changes, since their states default to None, this wouldn't be breaking). However, if you have a subclassed path, this would break. (A possible fix would be to monkeypatch or wrap the provided class to supply a do-nothing init and accept None states, but even if that worked it would still be breaking from a developer perspective, since subclasses would no longer adhere to the abstractmethod specifications.) Curious if you have any thoughts on this approach (or if I am missing a non-breaking way to do this).
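For concreteness, a rough sketch of that "most abstract form". The signatures and the `path_state` name below are hypothetical, not Diffrax's actual API; they just show where the breakage comes from.

```python
import abc
from typing import Optional

class AbstractPath(abc.ABC):
    def init(self, t0, t1) -> Optional[object]:
        # Hypothetical default: paths are stateless unless they override this,
        # so default UBP/VBT usage is unaffected.
        return None

    @abc.abstractmethod
    def evaluate(self, t0, t1=None, left=True, path_state=None):
        # Adding `path_state` here is the breaking part: existing third-party
        # subclasses define evaluate without it and so no longer match the
        # abstractmethod's specification.
        ...
```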
So I think this is probably too dangerous to implement in the main library itself. Right now integration happens with respect to the same path regardless of step size controller, and having a special case where that's not the case is definitely a footgun.
Fortunately, via the custom solver API you can write something that does this regardless!
Something like this should work:
```python
import functools
from typing import Callable

import jax.tree_util as jtu
from diffrax import AbstractPath, AbstractWrappedSolver

class MyControl(AbstractPath):
    def evaluate(..., control_state):
        ...

class CallToEvaluate(AbstractPath):
    fn: Callable

    def evaluate(self, ...):
        return self.fn(...)

def is_control(x):
    return isinstance(x, MyControl)

def bind_control_state(terms, control_state):
    # Replace each control in the terms with a wrapper whose `evaluate` has the
    # current control state already bound.
    def _bind_control_state(x):
        if is_control(x):
            return CallToEvaluate(functools.partial(x.evaluate, control_state=control_state))
        else:
            return x
    return jtu.tree_map(_bind_control_state, terms, is_leaf=is_control)

class MySolver(AbstractWrappedSolver):
    def step(...):
        # Thread the control state through alongside the wrapped solver's own state.
        solver_state, control_state = solver_state
        terms = bind_control_state(terms, control_state)
        ... = self.solver.step(...)
        control_state = self.step_control_state(control_state)
        solver_state = (solver_state, control_state)
        return ...
```
Which is arguably the appropriate amount of fiddly for doing something questionable like this :)))
Yea, that's basically the implementation we have currently.
> Which is arguably the appropriate amount of fiddly for doing something questionable like this :)))
I do agree, which is why we were looking into supporting less questionable ways. Although since it's too dangerous, we'll just formalize it more and put it in our extensions.
In the docs it says "You do not need to backpropagate through the differential equation." for UBP usage. However, this doesn't seem to be theoretically necessary: you can just backprop through the solver with the added noise, right? What's the motivation for requiring this to be the case?
It says "Internally this operates by just sampling a fresh normal random variable over every interval, ignoring the correlation between samples exhibited in true Brownian motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.)" which makes sense inre adaptivity (since you need a brownian bridge or something of the like), but not for differentiation.