Open MikeInnes opened 6 years ago
Well, there’s no control flow in this version so it doesn’t have to be a compiler pass and makes sense as pure overloading, but with control flow in the picture there are (at least?) two choices:
The problem as far as examples are concerned (or maybe the solution??) is that most use cases I’ve run into stress either the mask propagation stuff (what’s implemented here) or control flow but not both.
Other things that are relevant:
DLArray
type (unions are probably banned, so maybe I’ll have to use a sentinel like -1):
```julia
struct Dim
    dynamic::Bool
    size::Union{Int, Nothing}
    name::Symbol
end
```
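If unions do turn out to be banned, the sentinel version could look roughly like the sketch below. The names `SentinelDim`, `UNKNOWN_SIZE` and `hassize` are made up for illustration and are not part of any existing package.

```julia
# Hypothetical union-free variant of `Dim`: a size of -1 means "unknown".
const UNKNOWN_SIZE = -1

struct SentinelDim
    dynamic::Bool
    size::Int      # UNKNOWN_SIZE when the extent is not known statically
    name::Symbol
end

# Convenience constructor that maps `nothing` onto the sentinel,
# so callers can keep writing `nothing` for an unknown size.
SentinelDim(dynamic::Bool, size::Nothing, name::Symbol) =
    SentinelDim(dynamic, UNKNOWN_SIZE, name)

# Query that hides the sentinel from the rest of the code.
hassize(d::SentinelDim) = d.size != UNKNOWN_SIZE

batch = SentinelDim(true, nothing, :batch)
feature = SentinelDim(false, 128, :feature)
```

The trade-off is that every consumer has to go through something like `hassize` rather than relying on dispatch over `Nothing`.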
I agree that it seems slightly easier to do batching before AD, which still leaves a few options:
- `Batched{Tracked{Array}}`
- `@overdub DiffCtx f(x::Batched{Array})`
- `@overdub DiffCtx @overdub BatchCtx f(x::Array)`, if that’s even possible
Coming from an ML background, it's natural to see minibatching as an operator-overloading / "gradient hacking" process: we try to coalesce operations, while also having masking semantics that manipulate how gradient information flows backwards. From a compiler perspective, by contrast, this has nothing to do with gradients; it's just an SPMD transform that in principle can take any (possibly non-differentiable) code and run with it. (This only works if the pass doesn't have any semantics specific to the use case, of course, but I believe that's the case.)
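To make the operator-overloading view concrete, here is a minimal sketch of coalesced operations with mask propagation. `MaskedBatch` and `masked_sum` are hypothetical names invented for this example, not the types in this package, and gradients are left out entirely.

```julia
# Hypothetical wrapper: `data` holds padded examples as columns,
# `mask` marks which entries are real (true) versus padding (false).
struct MaskedBatch{T}
    data::Matrix{T}
    mask::BitMatrix
end

# Pad variable-length vectors into one matrix and record the mask.
function MaskedBatch(xs::Vector{Vector{T}}) where {T}
    n = maximum(length, xs)
    data = zeros(T, n, length(xs))
    mask = falses(n, length(xs))
    for (j, x) in enumerate(xs)
        data[1:length(x), j] .= x
        mask[1:length(x), j] .= true
    end
    return MaskedBatch(data, mask)
end

# Elementwise ops run once over the whole batch; the mask is propagated.
Base.:+(a::MaskedBatch, b::MaskedBatch) =
    MaskedBatch(a.data .+ b.data, a.mask .& b.mask)
Base.:*(a::MaskedBatch, b::MaskedBatch) =
    MaskedBatch(a.data .* b.data, a.mask .& b.mask)

# A reduction that ignores padding: per-example sum over valid entries only.
masked_sum(a::MaskedBatch) = vec(sum(a.data .* a.mask; dims=1))

xs = MaskedBatch([[1.0, 2.0, 3.0], [4.0]])
ys = xs * xs + xs
totals = masked_sum(ys)   # one total per example, padding excluded
```

Each example ends up with the same total it would get if run on its own, which is the property the masking semantics are there to preserve.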
The main operational difference is whether the pass runs before or after the AD, respectively. In principle, the compiler approach could be much more general here in that you don't have to worry about any derivative semantics; you just apply the SPMD transform to the forward and backward passes separately. The downside is that you must now work with a larger program that contains AD machinery (e.g. mutating tapes), so it's likely to be significantly harder to implement. A pure operator-overloading (OO) implementation is essentially equivalent to a naive AD, so it's pretty easy to do.
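As a toy illustration of the "batch after AD" idea: the transform below knows nothing about gradients and is applied to the forward and backward functions independently. Everything here is a made-up sketch; `spmd` just uses `map` as a stand-in for a real vectorising pass, and the backward function `df` is written by hand rather than produced by an AD tool.

```julia
# Per-example forward pass, with a branch to stand in for control flow.
f(x) = x > 0 ? x^2 : -x

# Hand-written backward pass (pullback) for `f`: maps the output
# sensitivity `dy` to the input sensitivity.
df(x, dy) = (x > 0 ? 2x : -one(x)) * dy

# Toy "SPMD transform": run any scalar function across the lanes of a batch.
# A real pass would vectorise and handle divergent branches with masks.
spmd(g) = (args...) -> map(g, args...)

xs = [1.0, -2.0, 3.0]
ys = spmd(f)(xs)                          # batched forward pass
grads = spmd(df)(xs, ones(length(xs)))    # batched backward pass
```

The point of the sketch is only that `spmd` never needs to know which of the two functions is the derivative.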
One thing that might swing it in favour of OO is the potential for optimisations when you can manipulate both passes together (e.g. injecting forward mode); though I could equally see optimisations coming from not having to worry about AD, too.
For my part, I need to play with this package more and figure out a representative example to guide thinking. I'd also like to understand how hard the general transformation is, which might mean looking at efforts like ISPC.