gdalle opened 6 months ago
Yeah, I am all for purging that code and depending on DifferentiationInterface. The major things that need to work:
That should be enough to get a start.
Tracker works and is tested; ComponentArrays is my next target.
That's great. This week is a bit busy, but I can possibly give it a try early next month.
No rush! And always happy to help debug. We're also chatting with Flux to figure out how best to support their use case, which is slightly more complex for lack of a dedicated parameter type like ComponentVector.
Is a vector input strictly required? ComponentArrays has an overhead for smallish arrays (see https://github.com/LuxDL/Lux.jl/issues/49), so having an fmap-based API might be good, though that is not really a big blocker.
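For context, the ComponentArrays route means flattening Lux's nested NamedTuple parameters into a single `ComponentVector` so that a vector-input AD API applies. A rough sketch (assuming ComponentArrays.jl; the exact constructor behavior for nested NamedTuples is the documented one, but this is illustrative, not Lux's actual code path):

```julia
using ComponentArrays

# Nested NamedTuple parameters, in the shape Lux typically produces
ps = (layer_1 = (weight = rand(2, 2), bias = rand(2)),
      layer_2 = (weight = rand(1, 2), bias = rand(1)))

# Flat vector with named axes: works with any vector-input gradient API
ps_flat = ComponentVector(ps)

ps_flat.layer_1.weight   # structured access is still available
length(ps_flat)          # 4 + 2 + 2 + 1 = 9 scalar parameters
```

The overhead mentioned above comes from this axis bookkeeping, which is noticeable for small arrays; an fmap-based API would instead walk the nested structure directly without flattening.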
At the moment, yes. We're thinking about how to be more flexible in order to accommodate Flux's needs; you can track https://github.com/gdalle/DifferentiationInterface.jl/issues/87 to see how it evolves.
DI v0.3 should be out sometime next week (I've been busy with sparse Jacobians & Hessians), but I don't think I'll have much time in the near future to revamp the Lux tests. Still, I think it would make sense to offer DI at least as a high-level interface, even if it is not yet used in the package internals or tests. It might also help you figure out #605.
Note that for DI to work in full generality with ComponentArrays, I need https://github.com/jonniedie/ComponentArrays.jl/issues/254 to be fixed. Otherwise, Jacobians and Hessians will stay broken (the rest, in particular gradient, is independent of stacking).
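To illustrate why gradient is unaffected: it only ever produces one tangent with the same shape as the input, so no stacking of columns is required. A hedged sketch of what should already work (assuming the DI `gradient(f, backend, x)` entry point and the ADTypes backend structs; package versions may change the details):

```julia
using DifferentiationInterface   # provides gradient(f, backend, x)
using ADTypes: AutoZygote
using Zygote                     # backend package must be loaded
using ComponentArrays

# Parameters as a ComponentVector with named blocks
ps = ComponentVector(weight = rand(2, 2), bias = rand(2))

# A scalar loss written against the structured fields
loss(p) = sum(abs2, p.weight) + sum(abs2, p.bias)

# Gradient returns a single ComponentVector tangent: no stacking needed.
# jacobian/hessian would need to stack many such tangents into a matrix,
# which is where the ComponentArrays issue linked above bites.
g = gradient(loss, AutoZygote(), ps)
```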
Yes, I want to roll it out first as a high-level interface when the inputs are AbstractArray, rather than for testing. I still want to support arbitrary structured parameters as input for now.
Hey there Avik!
As you may know, I have been busy developing DifferentiationInterface.jl, and it's really starting to take shape. I was wondering if it would be useful for Lux.jl as a dependency, in order to support a wider variety of autodiff backends defined by ADTypes.jl?
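To make the proposal concrete, here is a rough sketch of the kind of backend-agnostic call DI aims to offer (function names may still shift before registration; `AutoZygote`/`AutoTracker` are the backend structs from ADTypes.jl, and this toy loss is mine, not Lux's):

```julia
using DifferentiationInterface   # gradient(f, backend, x)
using ADTypes: AutoZygote, AutoTracker
using Zygote, Tracker            # backend packages must be loaded

loss(x) = sum(abs2, x)           # toy loss on a flat parameter vector
x = rand(Float32, 4)

# The same call works for every supported backend:
# only the ADTypes object changes, no per-backend extension code.
g_zygote  = gradient(loss, AutoZygote(), x)
g_tracker = gradient(loss, AutoTracker(), x)
```

Swapping backends by swapping one argument is exactly what would replace the per-backend gradient extensions in Lux.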
Looking at the code, it seems the main spot where AD comes up (beyond the docs and tutorials) is Lux.Training: https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/src/contrib/training.jl#L93-L106
Gradients are only implemented in the extensions for Zygote and Tracker:
https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/ext/LuxZygoteExt.jl#L7-L14
https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/ext/LuxTrackerExt.jl#L25-L33
While DifferentiationInterface.jl is not yet ready or registered, it has a few niceties like Enzyme support which might pique your interest. I'm happy to discuss with you and see what other features you might need.
The main one I anticipate is compatibility with ComponentArrays.jl (https://github.com/gdalle/DifferentiationInterface.jl/issues/54), and I'll try to add it soon.
cc @adrhill