jekbradbury / Minibatch.jl


Is this a compiler pass? #2

Open MikeInnes opened 6 years ago

MikeInnes commented 6 years ago

Coming from an ML background, it's natural to see minibatching as an operator-overloading / "gradient hacking" process: we try to coalesce operations, while also having masking semantics that manipulate how gradient information flows backwards. From a compiler perspective, by contrast, this has nothing to do with gradients; it's just an SPMD transform that in principle can take any (possibly non-differentiable) code and run with it. (This only works if the pass doesn't have any semantics specific to the use case, of course, but I believe that's the case.)

The main operational difference is whether the batching pass runs before the AD (the OO view) or after it (the compiler view). In principle, the compiler approach could be much more general here in that you don't have to worry about any derivative semantics; you just SPMD the forward and backward passes separately. The downside is that you must now work with a larger program that contains AD machinery (e.g. mutating tapes), so it's likely to be significantly harder to implement. A pure OO implementation is essentially equivalent to a naive AD, so it's pretty easy to do.
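
For concreteness, here is a hedged sketch of the pure-OO view: a wrapper type carries a padded batch plus a mask, elementwise and shared-weight ops coalesce into single batched calls, and masking is an explicit select that also governs what a naive OO AD would see on the backward pass. `BatchedVec`, `batched_mul`, and `select` are illustrative names, not Minibatch.jl's actual API.

    # Minimal operator-overloading sketch; names and layout are assumptions.
    struct BatchedVec{T}
        data::Matrix{T}       # features × batch, padded to the longest example
        alive::Vector{Bool}   # which batch columns are "real" at this step
    end

    # Elementwise op: one fused broadcast over the whole padded batch.
    Base.tanh(x::BatchedVec) = BatchedVec(tanh.(x.data), x.alive)

    # Shared-weight matvec: one GEMM for the whole batch.
    batched_mul(W::AbstractMatrix, x::BatchedVec) = BatchedVec(W * x.data, x.alive)

    # Masking: dead columns keep their previous value, so padded examples neither
    # affect the result nor (under a naive OO AD) receive gradient.
    select(new::BatchedVec, old::BatchedVec) =
        BatchedVec(ifelse.(reshape(new.alive, 1, :), new.data, old.data), new.alive)

    h = BatchedVec(rand(4, 3), [true, true, false])   # batch of 3, one padded out
    select(tanh(h), h)                                # op applied, padding held fixed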

One thing that might swing it in favour of OO is the potential for optimisations when you can manipulate both passes together (e.g. injecting forward mode); though I could equally see optimisations coming from not having to worry about AD, too.

For my part, I need to play with this package more and figure out a representative example to guide thinking. I'd also like to understand how hard the general transformation is, which might mean referring to efforts like ISPC.

jekbradbury commented 6 years ago

Well, there’s no control flow in this version so it doesn’t have to be a compiler pass and makes sense as pure overloading, but with control flow in the picture there are (at least?) two choices:

  1. It’s a “true” compiler pass; either use a Cassette pass if we want to do a bit of manual type inference or go all the way and implement it on typed SSAIR.
  2. Eh, we can get rid of control flow before lowering with FunctionalControlFlow.jl; after that it’s just dispatch again and we can either use plain dispatch or Cassette contextual dispatch if plain dispatch is too ambiguous/hairy (a contextual-dispatch sketch follows this list).
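
For the second route, a hedged sketch of what the contextual-dispatch half could look like; the context name and the intercepted primitive are assumptions, not this package's code.

    using Cassette

    Cassette.@context BatchCtx

    # Under BatchCtx, intercept a scalar-style primitive and run a rule that
    # operates on a whole padded batch at once; everything without a special
    # rule just recurses through ordinary Julia code.
    Cassette.overdub(::BatchCtx, ::typeof(tanh), x::AbstractMatrix) = tanh.(x)

    cell(h) = tanh(h) .+ 1            # toy "model" with no batch awareness
    Cassette.overdub(BatchCtx(), cell, rand(3, 4))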

The problem as far as examples are concerned (or maybe the solution??) is that most use cases I’ve run into stress either the mask propagation stuff (what’s implemented here) or control flow but not both.

Other things that are relevant:

  1. I’m working with FB to port the Python version to C++ and include it as a pass in the PyTorch compiler (there’s an early PR but this is a rest-of-the-summer project for an FB intern)
  2. In my ideal world this would be integrated with an XLA/ONNX export story because both rely on carefully distinguishing static and dynamic array dimensions. I think I want a tuple of something like this as a type parameter for some DLArray type (unions are probably banned, so maybe I’ll have to use a sentinel like -1):
    struct Dim
        dynamic::Bool
        size::Union{Int, Nothing}
        name::Symbol
    end
jekbradbury commented 6 years ago

I agree that it seems slightly easier to do batching before AD, which still leaves a few options (option 1 is sketched after this list):

  1. OO batching and OO AD: `Batched{Tracked{Array}}`
  2. OO batching and Cassette AD: `@overdub DiffCtx f(x::Batched{Array})`
  3. Cassette batching and Cassette AD: `@overdub DiffCtx @overdub BatchCtx f(x::Array)`, if that’s even possible
  4. whatever other options exist using later compiler-pass injection
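
To make option 1 concrete, a hedged sketch of how the nested wrappers would dispatch. The Tracked/Batched internals here are stand-ins, and the per-example loop only illustrates the call order; real batching would coalesce into one padded call.

    struct Tracked{A}          # stand-in for an OO AD wrapper (tape state elided)
        value::A
    end

    struct Batched{A}          # stand-in for the OO batching wrapper
        examples::Vector{A}    # real code would hold padded data + masks instead
    end

    addbias(x::AbstractVector, b) = x .+ b
    addbias(x::Tracked, b) = Tracked(addbias(x.value, b))                  # AD rule sees plain arrays
    addbias(x::Batched, b) = Batched([addbias(e, b) for e in x.examples])  # batch rule fires first

    xs = Batched([Tracked(rand(3)), Tracked(rand(3))])
    addbias(xs, 1.0)   # the outer Batched method dispatches before the Tracked (AD) method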