gdalle opened this issue 4 months ago
The answer will be different for Flux and Lux, the former most likely requiring support for Functors.jl.
It's worth noting that parameters in DL are somewhat of a mess: some are trainable arrays, some are trainable scalar values, and others are non-trainable parameters, e.g. moving statistics in BatchNorm layers.

> It's worth noting that parameters in DL are somewhat of a mess: some are trainable arrays, some are trainable scalar values, and others are non-trainable parameters, e.g. moving statistics in BatchNorm layers.
This is what Lux and its derivative frameworks were designed to fix. LuxCore.jl is basically the interface specification that Lux models need to abide by. Lux simply specifies bindings to LuxLib/NNlib and some utilities users need for deep learning abiding by those specifications. To summarize the interface specification, you have 4 parts:

- `model`: an immutable structure describing the problem (more specifically, the neural network architecture). Its values never participate in gradient calculation.
- `ps`: parameters. All of these are trainable. They may contain scalars but generally don't. When we want to compute the gradients, this is what we are typically computing the gradients for.
- `st`: states (think of them as non-trainable parameters). These can be anything: arrays, `Val`s, `Number`s, etc. Their values never participate in gradient calculation.
- `data`: you might need to compute the gradients for this if users want it, but typically this is not required.

> The answer will be different for Flux and Lux, the former most likely requiring support for Functors.jl.
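To make the 4-part split concrete, here is a minimal sketch using only Base Julia. The `Dense` struct and the `init_params`/`init_state` helpers are illustrative stand-ins, not the real Lux layer or API: the point is only that the model is immutable, while `ps` and `st` travel alongside it and the (possibly updated) state is returned with the output.

```julia
# Illustrative stand-in for a Lux-style layer; not the real Lux.Dense.
struct Dense
    in::Int
    out::Int
end

# Parameters (`ps`): trainable values, created separately from the model.
init_params(d::Dense) = (weight = randn(d.out, d.in), bias = zeros(d.out))

# States (`st`): non-trainable values; empty for this layer.
init_state(::Dense) = NamedTuple()

# Forward pass: the model is immutable, ps/st are passed explicitly,
# and the (possibly updated) state is returned alongside the output.
function (d::Dense)(x, ps, st)
    y = ps.weight * x .+ ps.bias
    return y, st
end

d = Dense(3, 2)
ps = init_params(d)
st = init_state(d)
y, st = d(ones(3), ps, st)
```

When computing gradients, only `ps` would be marked active; `model` and `st` stay out of the differentiated path, matching the specification above.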
Lux uses a restrictive definition of `fmap` (for operations on parameters) for type stability (https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.recursive_map), but `fmap` is the general solution for both.
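The core idea behind an `fmap`-style operation can be sketched in a few lines: recursively walk a nested parameter structure and apply a function to every array leaf. The real `Functors.fmap` and `Lux.recursive_map` handle far more (shared leaves, caching, custom structs, type stability); this hand-rolled `recursive_map` is just the essence.

```julia
# Minimal sketch of an fmap-style recursive map over nested NamedTuples:
# apply `f` to every array leaf, pass other leaves through untouched.
recursive_map(f, x::NamedTuple) = map(v -> recursive_map(f, v), x)
recursive_map(f, x::AbstractArray) = f(x)
recursive_map(f, x) = x  # non-array leaves (e.g. states) pass through

ps = (layer1 = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0]),
      layer2 = (weight = [5.0 6.0],))

ps2 = recursive_map(x -> 2 .* x, ps)
# ps2.layer1.weight == [2.0 4.0; 6.0 8.0]
```

This is the shape of operation an AD backend needs in order to, say, scale or accumulate gradients across an arbitrary parameter tree.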
I don't have a strong view on exactly what structs you should test, but I do know of several things that you will need to make decisions about, based on my experience helping @oxinabox design the tangent type system in ChainRules, and my experience with Tapir.jl.
Firstly, there are a couple of edge cases that you'll probably want to actively ensure that people avoid in order to reduce the number of tests you have to write:
Additionally, you'll need to consider whether to follow ChainRules' v1 approach and be flexible regarding what type is used to represent the tangent of a given struct of a given type, or whether to go down the route that Tapir.jl and Enzyme.jl take of insisting on there being a unique tangent type for each primal type. If you choose the former, you massively blow up the interface surface that you'll have to test. Moreover, you run the risk of different AD backends giving different answers and them both technically being "correct".
Personally, I would encourage you to take an opinionated view, and insist upon unique tangent types. I doubt it will matter too much what types you pick, but my experience is that being restrictive makes your life much easier.
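One way to picture the "unique tangent type per primal type" approach is a single function from primal types to tangent types. The `tangent_type` name and the specific choices below are hypothetical, for illustration only; they are not Mooncake's or ChainRules' actual API.

```julia
# Hypothetical sketch: one function decides, for each primal type,
# what its unique tangent type is. Names are illustrative only.
tangent_type(::Type{Float64}) = Float64
tangent_type(::Type{<:Array{Float64,N}}) where {N} = Array{Float64,N}
tangent_type(::Type{Int}) = Nothing  # integers treated as non-differentiable

# For a struct, one natural choice is a NamedTuple with one entry per field:
struct Point
    x::Float64
    y::Float64
end
tangent_type(::Type{Point}) = @NamedTuple{x::Float64, y::Float64}

tangent_type(Vector{Float64})  # Vector{Float64}
tangent_type(Point)            # @NamedTuple{x::Float64, y::Float64}
```

The restrictiveness pays off in testing: for each primal type there is exactly one answer to check, rather than a family of "equally correct" representations.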
I hope the above is helpful. I'm very excited to see what we wind up with here!
(Also, I'm on holiday at the minute, so I probably won't be super responsive to this thread until next week. Apologies in advance!)
Thanks for your advice! Before you go, I'd love it if you could take a look at my multi-argument/activity proposal in https://github.com/gdalle/DifferentiationInterface.jl/issues/311#issuecomment-2228066950 and see whether there are any obvious things we cannot do with it?
@gdalle did you ever think any more about this?
The release of Julia 1.11 has prompted me to restart this discussion, because `Array`s are now generic structs. This is relevant because (once I've finished upgrading Mooncake for 1.11) if you take the gradient of a function w.r.t. an `Array`, you should no longer expect to get an `Array` back by default (you'll get a `Mooncake.MutableTangent`, or something like that). I'm assuming that this wasn't a problem before because DI's tests all use `Array`s(?).
While non-array-like things are the correct thing to use internally in Mooncake, they're probably not what we want to present to users. I'm keen to write some convenience functionality on my end to provide translations (for some types), but before doing that I would like to know what you would like in DI.
For example, I'm reasonably sure we would agree that an acceptable type for the gradient of a function w.r.t.

- a `Vector{Float64}` is another `Vector{Float64}`,
- an `Array{Float64, N}` is another `Array{Float64, N}`,
- a `Float64` is another `Float64`,

but what about more generic types, e.g. `Diagonal`, component arrays, etc.? I can't see this formalised anywhere, so it would be good to agree on it. What happens if we have complicated element types in a given array? Maybe a useful exercise would be to define, for some specific types, what the type of the result ought to be, and to clearly state which set of types DI has strong opinions on, and which it does not yet have strong opinions on.
The goal of DI is to be as unopinionated as possible, so I probably won't be taking sides here. Think of DI as a fancy argument-passer, which returns whatever the backends return.
There have been endless discussions on the meaning of derivatives when you're on a manifold, and this meaning differs between backends. From what I understand, ChainRules tries to preserve structure while Enzyme takes a more cartesian approach, so there is no universally right answer. If I try to unify return types for structured objects, I will definitely make a lot of people unhappy, and probably trash performance in the process.
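The structural-vs-cartesian split can be made concrete with a hand-computed example (no AD involved, derivatives worked out by hand): the gradient of `f(D) = sum(D)` at a `Diagonal` matrix comes out differently under the two conventions.

```julia
using LinearAlgebra

# Gradient of f(D) = sum(D) at a Diagonal matrix, computed by hand
# under the two conventions.
D = Diagonal([1.0, 2.0, 3.0])

# "Cartesian" convention (Enzyme-flavoured): treat D as a dense 3x3 matrix
# and differentiate w.r.t. every entry; the gradient of sum is all ones.
g_cartesian = ones(3, 3)

# "Structural" convention (ChainRules-flavoured): only the diagonal entries
# are free parameters, so the gradient keeps the Diagonal structure.
g_structural = Diagonal(ones(3))

# Both are "correct", but they differ off the diagonal:
g_cartesian == Matrix(g_structural)  # false
```

Any unification DI attempted would have to pick one of these answers and break the other convention's users, which is the crux of the argument above.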
There are also differences in how each backend handles certain fields in a struct. Some backends error on integers (ChainRules?), others just ignore them as inactive values (Enzyme?), and others differentiate them fine (FiniteDiff?). Some backends even ignore numbers and differentiate only arrays (Tracker?).
Similarly, some backends accept arbitrary tangent types, while other backends (Enzyme and Mooncake) are stricter. For the stricter ones, I implement automatic conversion, but not automatic structure adaptation. In other words, if `convert(correct_tangent_type, tangent)` fails, you're on your own.
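The conversion-but-not-adaptation distinction is just Julia's own `convert` semantics, which this sketch illustrates (the NamedTuple below is a made-up stand-in for a structural tangent):

```julia
# A tangent whose type Julia already knows how to `convert` goes through:
tangent = Float32[1, 2, 3]
converted = convert(Vector{Float64}, tangent)  # element-type conversion works

# But a structurally different tangent (a NamedTuple standing in for a
# struct-shaped tangent) has no `convert` method to Vector, so it throws:
struct_tangent = (diag = [1.0, 2.0, 3.0],)
err = try
    convert(Vector{Float64}, struct_tangent)
    nothing
catch e
    e
end
err isa MethodError  # true: no automatic structure adaptation
</imports>
```

So DI forwards whatever `convert` can do, and anything beyond that is the caller's responsibility.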
DI is thoroughly tested with the standard `Array` type, but the test suite is implemented with `isapprox`, so there is no requirement to return the same type as the reference we compare against. You can even pass your own `isapprox` function for structured outputs, if your test scenario has specific semantics (e.g. you want to ignore a subset of fields in the struct).
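A custom `isapprox` of that kind might look like the following sketch; the `LayerState` struct and its fields are made up purely for illustration.

```julia
# Hypothetical structured output: one differentiable field plus bookkeeping.
struct LayerState
    weight::Matrix{Float64}
    step_count::Int   # bookkeeping field to ignore in comparisons
end

# Custom comparison: only the field with gradient semantics is checked.
my_isapprox(a::LayerState, b::LayerState; kwargs...) =
    isapprox(a.weight, b.weight; kwargs...)

a = LayerState([1.0 2.0; 3.0 4.0], 10)
b = LayerState([1.0 2.0; 3.0 4.0] .+ 1e-12, 999)
my_isapprox(a, b)  # true: step_count is ignored
```

Passing such a function to the test scenario lets the suite compare structured results without forcing every backend to agree on the inactive fields.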
TLDR: Everything is in place to differentiate non-`Array`s, but the semantics are up to the backend to decide.
Fair enough. In that case I'll ignore this issue until the upgrades are done, and figure out how to make everything work on the DI end when we get to it :)
Does your new tangent type behave like an array? Can one index it, sum it, etc.?
It almost certainly won't by default.
edit: I say "almost" because I'm not 100% sure what the best choice is from Mooncake's perspective yet.
Let's discuss it in https://github.com/compintell/Mooncake.jl/issues/286?
> but what about more generic types? e.g. `Diagonal`, component arrays, etc. I can't see this formalised anywhere, so it would be good to agree on it. What happens if we have complicated element types in a given array?
I remember this has been discussed at length in ChainRules as well, although I couldn't find the relevant link. I think one outcome of those discussions is `ProjectTo`.
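The idea behind `ProjectTo` can be shown with a hand-rolled sketch: map a dense cotangent back onto the structure of the primal. The `project_to` function below is not the real ChainRulesCore API (which handles many more types and edge cases); it only illustrates the projection for the `Diagonal` case.

```julia
using LinearAlgebra

# Hand-rolled sketch of the ProjectTo idea: given a Diagonal primal,
# project a dense cotangent onto the Diagonal structure by keeping
# only its diagonal. Illustrative, not the ChainRulesCore API.
project_to(primal::Diagonal, cotangent::AbstractMatrix) =
    Diagonal(diag(cotangent))

D = Diagonal([1.0, 2.0, 3.0])
dense_cotangent = ones(3, 3)     # e.g. produced by a generic dense rule
project_to(D, dense_cotangent)   # Diagonal([1.0, 1.0, 1.0])
```

This is one resolution of the structural-vs-cartesian ambiguity discussed above: rules may compute dense cotangents internally, but the result is projected back to the primal's structure before being returned.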
What kind of structs should we add to enable deep learning applications?