LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/

Externalize gradient computations to DifferentiationInterface.jl? #544

Open gdalle opened 6 months ago

gdalle commented 6 months ago

Hey there Avik!

As you may know, I have been busy developing DifferentiationInterface.jl, and it's really starting to take shape. I was wondering whether it would be useful to Lux.jl as a dependency, in order to support the wider variety of autodiff backends defined by ADTypes.jl?
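For concreteness, here is a minimal sketch of the backend-agnostic call DI provides (assuming the `value_and_gradient` entry point and the ADTypes backend structs; exact signatures may still shift while the package is pre-registration):

```julia
using DifferentiationInterface
using ADTypes: AutoZygote, AutoForwardDiff
import Zygote, ForwardDiff  # backends must be loaded to be usable

f(x) = sum(abs2, x)  # toy objective
x = rand(3)

# One call, any backend: swapping AD systems is a one-argument change
val, grad = value_and_gradient(f, AutoZygote(), x)
val, grad = value_and_gradient(f, AutoForwardDiff(), x)
```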

Looking at the code, it seems the main spot where AD comes up (beyond the docs and tutorials) is Lux.Training:

https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/src/contrib/training.jl#L93-L106

Gradients are only implemented in the extensions for Zygote and Tracker:

https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/ext/LuxZygoteExt.jl#L7-L14

https://github.com/LuxDL/Lux.jl/blob/c27b9f5f57df2fcfa0bd1f6a6834b04f48d2ba61/ext/LuxTrackerExt.jl#L25-L33
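To sketch the consolidation I have in mind (hypothetical code only; the training-state field names and the objective-function signature are placeholders for illustration, not Lux's actual API):

```julia
using DifferentiationInterface

# Hypothetical sketch. Assumes a training state `ts` with fields
# `model`, `parameters`, `states`, and an objective returning (loss, st, stats).
function compute_gradients_via_di(backend, objective_function, data, ts)
    # DI differentiates a scalar-valued function, so the updated states/stats
    # would need to be recovered separately (e.g. via a capturing closure).
    loss_fn = ps -> first(objective_function(ts.model, ps, ts.states, data))
    loss, grads = value_and_gradient(loss_fn, backend, ts.parameters)
    return grads, loss
end
```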

While DifferentiationInterface.jl is not yet ready or registered, it already has a few niceties, such as Enzyme support, that might pique your interest. I'm happy to discuss and see what other features you might need.

The main one I anticipate is compatibility with ComponentArrays.jl (https://github.com/gdalle/DifferentiationInterface.jl/issues/54), and I'll try to add it soon.
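As a sketch of what that compatibility would enable (hypothetical usage, assuming DI's exported `gradient` function and a Zygote backend):

```julia
using ComponentArrays, DifferentiationInterface
using ADTypes: AutoZygote
import Zygote

# Structured parameters backed by a single flat vector
ps = ComponentVector(weight = rand(2, 2), bias = rand(2))
loss(p) = sum(abs2, p.weight) + sum(abs2, p.bias)

# The gradient should come back as a ComponentVector with matching axes
grads = gradient(loss, AutoZygote(), ps)
```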

cc @adrhill

avik-pal commented 6 months ago

Yeah, I am all for purging that code and depending on DifferentiationInterface instead. The major things that need to work:

  1. ~~Tracker~~
  2. ComponentArrays (probably just testing)

That should be enough to get a start.

gdalle commented 6 months ago

Tracker works and is tested; ComponentArrays is my next target.

avik-pal commented 6 months ago

That's great. This week is a bit busy, but I can possibly give it a try early next month.

gdalle commented 6 months ago

No rush! And I'm always happy to help debug. We're also chatting with the Flux folks to figure out how best to support their use case, which is slightly more complex because Flux lacks a dedicated parameter type like ComponentVector.

avik-pal commented 6 months ago

Do you strictly require vector inputs? ComponentArrays has an overhead for smallish arrays (see https://github.com/LuxDL/Lux.jl/issues/49), so having an fmap-based API might be good, though that is not really a big blocker.
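To illustrate what I mean by an fmap-based API (a hypothetical sketch using Functors.jl's multi-argument `fmap`, not an existing Lux or DI API):

```julia
using Functors  # provides fmap for walking nested parameter structures

ps    = (weight = rand(2, 2), bias = rand(2))   # structured params, no flattening
grads = (weight = ones(2, 2), bias = ones(2))   # matching gradient structure

# Hypothetical fmap-based update: apply the step leaf-by-leaf,
# avoiding the round trip through a flat ComponentVector
ps_new = fmap((p, g) -> p .- 0.1 .* g, ps, grads)
```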

gdalle commented 6 months ago

At the moment, yes. We're thinking about how to be more flexible in order to accommodate Flux's needs; you can track https://github.com/gdalle/DifferentiationInterface.jl/issues/87 to see how it evolves.

gdalle commented 5 months ago

DI v0.3 should be out sometime next week (I've been busy with sparse Jacobians & Hessians), but I don't think I'll have much time in the near future to revamp the Lux tests. Still, I think it would make sense to offer DI at least as a high-level interface, even if it is not yet used in the package internals or tests. It might also help you figure out #605.

gdalle commented 5 months ago

Note that for DI to work in full generality with ComponentArrays, I need https://github.com/jonniedie/ComponentArrays.jl/issues/254 to be fixed. Otherwise, Jacobians and Hessians will stay broken (the rest, in particular gradients, is independent of stacking).

avik-pal commented 5 months ago

Yes, I want to roll it out first as a high-level interface for the case where the inputs are AbstractArrays, rather than for testing. I still want to support arbitrarily structured parameters as input for now.