MikeInnes opened 5 years ago
Can you spell out the difficulties regarding in-place updates? I don't see it as obvious.
One difficulty is that mutability is not part of the array API; there isn't an automatic way to discover whether an array is mutable without trying it and checking for an error. So that's where we'll need an `ismutable` trait that effectively records a database of things we are allowed to mutate, and users with new array types will have to overload `Flux.ismutable` (which isn't the end of the world, just kind of ugly).
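A rough sketch of what such a trait could look like (hypothetical names, not Flux's actual API):

```julia
# Hypothetical sketch of an ismutable trait: a conservative fallback plus
# opt-in methods for array types known to support in-place writes.
ismutable(::AbstractArray) = false   # default: assume we may not mutate
ismutable(::Array) = true            # plain Arrays can be written in place

# A package defining its own array type would opt in with one line, e.g.
# Flux.ismutable(::MyGPUArray) = true
```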
The second difficulty is working out the semantic issues around mutability. If you have a ref around an immutable model, presumably "do this update in place" actually means update the ref. What if the model actually has a mix of immutable and mutable arrays?
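For illustration only (not a proposed API), here is the Ref case: "update in place" can only mean replacing the Ref's contents, while the parameters themselves are rebuilt functionally.

```julia
# Illustrative only: an immutable model (a NamedTuple) held behind a Ref.
# "Update in place" can only mean storing a rebuilt model back into the Ref.
model = Ref((weight = randn(3, 3), bias = randn(3)))
grads = (weight = randn(3, 3), bias = randn(3))

m = model[]
model[] = (weight = m.weight .- 0.1 .* grads.weight,
           bias   = m.bias   .- 0.1 .* grads.bias)
```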
Then we need to figure out how to expose this choice as an API, while also sharing mutating and non-mutating code as much as possible, so that plugging in your own types is easy (not having to do everything twice). This applies both to "leaf types" that get updated themselves and to containers that only update their contents.
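One hedged sketch of how that sharing could work: a single non-mutating core, with thin mutating and non-mutating wrappers on top (names are illustrative, not a committed API):

```julia
struct Descent
    eta::Float64
end

# trivial stand-in for the ismutable trait discussed above
ismutable(x) = x isa Array

apply(o::Descent, x, dx) = x .- o.eta .* dx   # shared, purely functional core

update(o, x, dx) = apply(o, x, dx)            # always returns a new value

function update!(o, x, dx)
    ismutable(x) || return apply(o, x, dx)    # immutable leaf: return a copy
    x .= apply(o, x, dx)                      # mutable leaf: overwrite in place
    return x
end
```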
It's all doable, but quite a lot fiddlier than it initially looks, and will probably take some time to work out well.
Coming over from Slack: I think it would be quite useful if we could move the optimizers to a single package (perhaps even share code with other things like Optim, but moving to a package first). We are currently using the Flux optimizers in Yao, but Flux itself is quite a heavy dependency just for the optimizers.
Now that `ArrayInterface.ismutable(x)` (https://github.com/SciML/ArrayInterface.jl#ismutablex) exists, at least part of this is in place. As for avoiding the "two copies of ResNet in memory" problem, however, is there a more convenient API (in terms of both implementation and usage complexity) than the explicit param-passing that JAX, Haiku and co. use?
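For reference, the trait can be queried like this (assuming the `ismutable` name from ArrayInterface.jl at the time of writing; later versions may rename it):

```julia
using ArrayInterface

ArrayInterface.ismutable(rand(3))   # true  — Base Arrays support setindex!
ArrayInterface.ismutable(1:10)      # false — ranges are immutable
```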
I'd like to bump this. We recently needed to implement some of our own gradient-based optimizers for Yao, so I'm wondering if anyone is interested in splitting the Optimise module out as a package, perhaps with the new interface?
@Roger-luo you may be interested in https://github.com/FluxML/Optimisers.jl (bit of discussion in the issues) and https://github.com/FluxML/ML-Coordination-Tracker/discussions/22.
Should we close this issue? We now have the proposed API via Optimisers.jl, which has a mechanism for handling in-place updates within an immutable API.
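For anyone landing here later, the Optimisers.jl interface looks roughly like this (check its docs for the exact, current signatures):

```julia
using Optimisers

model = (weight = rand(3, 3), bias = zeros(3))          # any nested structure
state = Optimisers.setup(Optimisers.Descent(0.1), model)

grad = (weight = ones(3, 3), bias = ones(3))            # matching gradient
state, model = Optimisers.update(state, model, grad)    # out-of-place update
# Optimisers.update!(state, model, grad) writes into arrays where it safely can
```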
If we go full steam ahead with more functional-style AD (#628), then we'll need to rework optimisers a little. I think at the core the update step will look very similar, if a bit more functional, something like:
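A minimal sketch of that step, with hypothetical names (not a committed API):

```julia
# The optimiser rewrites the gradient and returns new optimiser state,
# leaving the parameter x itself untouched.
struct Momentum
    eta::Float64
    rho::Float64
end

function apply(o::Momentum, x, dx, state)
    v = o.rho .* state .+ o.eta .* dx   # new velocity doubles as the new state
    return v, v
end

x, dx = rand(3), ones(3)
state = zero(x)                                       # initial optimiser state
dx, state = apply(Momentum(0.1, 0.9), x, dx, state)   # x is not touched here
```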
Then we have an `update` function which actually applies the gradient. In general, the `x`, `dx` and `state` can be structs, and we'll rebuild `x` by recursing over it (somewhat like `mapleaves`, though we won't need to depend on `treelike` since we can just use reflection here).

I think this is all reasonably straightforward; the main wrinkle in my mind is how we enable in-place updates as an optimisation (since it's not ideal to have two copies of ResNet at once). I'm not aware of any great solution to this right now, so we might need to define an `ismutable` trait and declare it for types we care about.
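A hedged sketch of that recursion (a hypothetical `rebuild` helper, not Flux's actual `mapleaves`):

```julia
# Recurse over a struct via reflection, applying f at array leaves and
# rebuilding the containers functionally on the way back up.
rebuild(f, x::AbstractArray, dx::AbstractArray) = f(x, dx)

function rebuild(f, x::T, dx) where {T}
    fields = map(fieldnames(T)) do name
        rebuild(f, getfield(x, name), getfield(dx, name))
    end
    return T(fields...)
end

struct Affine{W,B}
    weight::W
    bias::B
end

m  = Affine(rand(3, 3), rand(3))
dm = Affine(ones(3, 3), ones(3))
m′ = rebuild((x, dx) -> x .- 0.1 .* dx, m, dm)   # new model, old one untouched
```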