Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Represent transpose of Delta as rewrite rules #100

Open Mikolaj opened 1 year ago

Mikolaj commented 1 year ago

Here's Tom's idea from https://github.com/Mikolaj/horde-ad/issues/95#issuecomment-1503013877:

the transposition can't be expressed as rewriting rules, because it's stateful.

It's only stateful because it's Cayley-transformed. The eval function:

eval :: R -> Delta -> DeltaMap -> DeltaMap

is really eval :: R -> Delta -> Endo DeltaMap: its codomain is Cayley-transformed (think DList difference lists) in order to make things more efficient. (addDelta becomes a bit more expensive this way, because it has to update a value in the map instead of being able to create a singleton map, but addition of maps is much cheaper because surely (.) is more efficient than Map.union.)
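To make the Cayley-transform remark concrete, here is a minimal sketch, not horde-ad's actual code. It assumes a toy scalar-only Delta with hypothetical constructors Zero, Scale, Add and Var, and takes DeltaMap to be a Map from Int variable ids to Double. evalMap builds maps directly and merges them; evalEndo returns an Endo DeltaMap, so Add becomes function composition and Var becomes an in-place update.

```haskell
import Data.Monoid (Endo(..))
import qualified Data.Map.Strict as M

type R = Double
type DeltaMap = M.Map Int R   -- assumption: variables are plain Ints

-- Toy Delta, scalar constructors only (hypothetical subset).
data Delta = Zero | Scale R Delta | Add Delta Delta | Var Int

-- Direct version: Var creates a singleton map, Add merges two maps.
evalMap :: R -> Delta -> DeltaMap
evalMap _ Zero        = M.empty
evalMap c (Scale y u) = evalMap (y * c) u
evalMap c (Add u1 u2) = M.unionWith (+) (evalMap c u1) (evalMap c u2)
evalMap c (Var n)     = M.singleton n c

-- Cayley-transformed version: Add is composition of endofunctions,
-- Var has to update the accumulated map instead of creating a singleton.
evalEndo :: R -> Delta -> Endo DeltaMap
evalEndo _ Zero        = mempty
evalEndo c (Scale y u) = evalEndo (y * c) u
evalEndo c (Add u1 u2) = evalEndo c u1 <> evalEndo c u2
evalEndo c (Var n)     = Endo (M.insertWith (+) n c)

-- 2*x0 + (x1 + 3*x0)
example :: Delta
example = Add (Scale 2 (Var 0)) (Add (Var 1) (Scale 3 (Var 0)))
```

Running appEndo (evalEndo 1 example) M.empty should give the same map as evalMap 1 example; the stateful-looking signature R -> Delta -> DeltaMap -> DeltaMap is just evalEndo with the newtype unwrapped.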

So maybe it can be expressed using rewrite rules, but just into the language of Endo DeltaMap, not of DeltaMap?

As far as I understand so far, we'd need more Delta constructors and we'd transpose by rewriting the Delta expressions. Afterwards, to get rid of the Delta constructors, we'd interpret them straightforwardly as linear transformation syntax, which is implemented (and probably bit-rotted) in buildDerivative, which computes forward derivatives on the basis of Delta collected in forward pass. In the pipeline that produces gradient Ast, we'd run buildDerivative with Ast codomain.

I wonder how much the extended Delta would start resembling Ast. If we switched from Delta to Ast, we'd end up with Ast nested inside Ast, which would be fine in this case. But perhaps Delta has some advantage over Ast? I guess it's linear, just as the linear sublanguage in the YOLO paper. I wonder if it's easy to extend Delta to be enough for this term rewriting, but still syntactically guaranteeing linearity (assuming it does ATM, after all the tensor extensions).

A modest version of this rewrite is to evaluate Delta to combinators of type DeltaMap -> DeltaMap, but not reify the combinators as constructors of Delta. Then, this can be presented as a rewrite, but implemented with immediate evaluation of the rewritten terms. Relevant parts of buildDerivative are then manually inlined.

tomsmeding commented 1 year ago

It's a bit stranger than I thought, but this is the best I've come up with so far.

Source language: our Delta expressions. Target language: terms in the target language described below, with a distinguished variable c in the environment. Target language terms t are typed with a judgement Γ | c : σ ||- t : τ. (Using ||- just to visually distinguish from the standard non-linear judgement |-, which is also used in some of the typing rules below.) These terms define an algebraically linear function σ -> τ with a non-linear environment Γ.

Given a source term Γ |- t : τ, transposition produces in the target language: Γ | c : τ ||- transp[t] : DMap. The DMap, Variable, Array and Shape types don't have further type indices in the typing rules below because I feared madness otherwise. There's probably a way to include those and make this type-safe.

# LINEARITY

------------------- (zero)
Γ | c : τ ||- 0 : σ

Γ | c : τ ||- s : σ      Γ | c : τ ||- t : σ
-------------------------------------------- (plus)
         Γ | c : τ ||- s + t : σ

Γ |- r : Real     Γ | c : τ ||- t : σ
------------------------------------- (scale)
       Γ | c : τ ||- r * t : σ

# LINEAR BINDING

Γ | c : τ ||- s : σ₁      Γ | c : σ₁ ||- t : σ₂
----------------------------------------------- (let)
      Γ | c : τ ||- let c = s in t : σ₂

------------------- (var)
Γ | c : τ ||- c : τ

# COTANGENT MAPS

        n Variable
------------------------------- (onehot map)
Γ | c : τ ||- Onehot n c : DMap

Γ | c : τ ||- t : DMap      n Variable
-------------------------------------- (map delete)
   Γ | c : τ ||- Delete n t : DMap

Γ | c : τ ||- t : DMap      n Variable
-------------------------------------- (map lookup)
   Γ | c : τ ||- Lookup n t : DMap

# ARRAY OPERATIONS

Γ |- sh : Shape       Γ |- f : Int -> Int
        Γ | c : τ ||- t : Array
----------------------------------------- (scatter)
   Γ | c : τ ||- Scatter sh f t : Array

Example transposition rules (rewritten from eval in the original POPL paper, and gather from simplified/HordeAd/Core/Delta.hs:buildFinMaps:evalR):

T[Zero] = 0
T[Scale y u] = let c = y * c in T[u]
T[Add u₁ u₂] = T[u₁] + T[u₂]
T[Var n] = Onehot n c
T[Let n u₁ u₂] = let c = T[u₂] in Delete n c + (let c = Lookup n c in T[u₁])
  -- Note that the first c binding above is of type DMap.
T[GatherZ sh u f sha] = let c = Scatter sha c f in T[u]
  -- I would rather write 'GatherZ sh f u sha' (or even 'GatherZ sh f u') and
  -- 'Scatter sha f c', but this is horde-ad argument order. :D
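The scalar rules above can be tried out with a small untyped interpreter. This is only a sketch under assumptions: the Term and Value types, the interp function and the Int variable ids are made up for illustration, GatherZ and Scatter are left out, 0 and + are only interpreted at type DMap (which is all the rules above need), scaling is only interpreted on reals, and Lookup is taken to return the cotangent stored for n (zero if absent), since that is what the Let rule feeds into T[u₁].

```haskell
import qualified Data.Map.Strict as M

type DMap = M.Map Int Double

-- Source language: a scalar-only subset of Delta (no GatherZ).
data Delta
  = Zero
  | Scale Double Delta
  | Add Delta Delta
  | Var Int
  | Let Int Delta Delta   -- Let n u1 u2: variable n stands for u1 inside u2

-- Target language; C is the distinguished variable c.
data Term
  = TZero
  | TPlus Term Term
  | TScale Double Term
  | TLet Term Term        -- let c = s in t
  | C
  | Onehot Int Term
  | Delete Int Term
  | Lookup Int Term

-- Untyped values: a scalar cotangent or a cotangent map.
data Value = VReal Double | VMap DMap
  deriving (Eq, Show)

-- The transposition rules T[.] for the scalar constructors.
transp :: Delta -> Term
transp Zero          = TZero
transp (Scale y u)   = TLet (TScale y C) (transp u)
transp (Add u1 u2)   = TPlus (transp u1) (transp u2)
transp (Var n)       = Onehot n C
transp (Let n u1 u2) =
  TLet (transp u2) (TPlus (Delete n C) (TLet (Lookup n C) (transp u1)))

asMap :: Value -> DMap
asMap (VMap m) = m
asMap v        = error ("expected a DMap, got " ++ show v)

asReal :: Value -> Double
asReal (VReal x) = x
asReal v         = error ("expected a real, got " ++ show v)

-- Interpret a term, given the current value bound to c.
interp :: Value -> Term -> Value
interp _ TZero        = VMap M.empty
interp c (TPlus s t)  = VMap (M.unionWith (+) (asMap (interp c s)) (asMap (interp c t)))
interp c (TScale r t) = VReal (r * asReal (interp c t))
interp c (TLet s t)   = interp (interp c s) t
interp _ C            = VReal 0 `seq` error "unreachable"
interp c (Onehot n t) = VMap (M.singleton n (asReal (interp c t)))
interp c (Delete n t) = VMap (M.delete n (asMap (interp c t)))
interp c (Lookup n t) = VReal (M.findWithDefault 0 n (asMap (interp c t)))

-- Gradient: run the transposed term with incoming cotangent 1.
gradient :: Delta -> Value
gradient = interp (VReal 1) . transp

-- let x0 = 2*x1 in x0 + 3*x0
example :: Delta
example = Let 0 (Scale 2 (Var 1)) (Add (Var 0) (Scale 3 (Var 0)))
```

Note that the clause for C must return the bound value; replace the placeholder clause `interp _ C = ...` with `interp c C = c` before use. With that clause, gradient example evaluates the body first (collecting the cotangent of x0), then deletes x0 from the map and pushes its cotangent through 2*x1.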

tomsmeding commented 1 year ago

While the above is a fine linear language, I think, the point of this was to Cayley-transform DMap, so we'd need to specialise 0 and + to only work on DMap values, i.e.

---------------------- (zero)
Γ | c : τ ||- 0 : DMap

Γ | c : τ ||- s : DMap      Γ | c : τ ||- t : DMap
-------------------------------------------------- (plus)
           Γ | c : τ ||- s + t : DMap

after which I think we can implement DMap as Endo DMap ~= DMap -> DMap without issues, using the standard monoid operations id and (.) for 0 and +. We'd have to double-check that Delete and Lookup remain implementable in the presence of let once DMap is Cayley-transformed. I think that works out, but I'm not 100% sure.
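One naive way to check the Delete and Lookup question is sketched below. Everything here is an assumption for illustration: the names zeroE, plusE, onehotE, deleteE, lookupE and transpLet are hypothetical, variables are Ints, cotangents are Doubles, and T[u₁] and T[u₂] are passed in as Haskell functions from the incoming cotangent to an Endo DMap rather than as syntax.

```haskell
import Data.Monoid (Endo(..))
import qualified Data.Map.Strict as M

type DMap  = M.Map Int Double
type DMapE = Endo DMap          -- Cayley-transformed DMap

zeroE :: DMapE
zeroE = mempty                  -- id

plusE :: DMapE -> DMapE -> DMapE
plusE = (<>)                    -- (.)

onehotE :: Int -> Double -> DMapE
onehotE n x = Endo (M.insertWith (+) n x)

-- Delete n after the contributions of f have been added.
-- Caveat: this also drops any entry for n already in the accumulator,
-- so it only matches the DMap semantics if n is fresh.
deleteE :: Int -> DMapE -> DMapE
deleteE n f = Endo (M.delete n) <> f

-- Lookup has to materialise f by running it on the empty map.
lookupE :: Int -> DMapE -> Double
lookupE n f = M.findWithDefault 0 n (appEndo f M.empty)

-- T[Let n u1 u2] = let c = T[u2] in Delete n c + (let c = Lookup n c in T[u1])
transpLet :: Int -> (Double -> DMapE) -> (Double -> DMapE) -> Double -> DMapE
transpLet n tU1 tU2 c =
  let m = tU2 c
  in deleteE n m `plusE` tU1 (lookupE n m)

-- let x0 = 2*x1 in x0 + 3*x0, with T[u1] and T[u2] written out by hand
exampleE :: Double -> DMapE
exampleE = transpLet 0 tU1 tU2
  where
    tU1 c = onehotE 1 (2 * c)
    tU2 c = onehotE 0 c `plusE` onehotE 0 (3 * c)
```

In this reading both operations are implementable, but m is run twice (once by lookupE on the empty map, once by deleteE on the real accumulator), so nested lets would repeat work. A stateful version that looks n up in the accumulator after running T[u₂] once would avoid that, at the cost of no longer being a plain composition of the two pieces.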

Mikolaj commented 1 year ago

Hah, yes, this is unexpected.

Re 'GatherZ sh f u', IIRC the extra shape is needed to compute the forward derivative from a delta expression without traversing the delta expression to reconstruct the shape. Putting u in front of f may be related to simplification, vectorization, etc., recursing over u rather than over f, or to indexing taking the term first and only then the index. These are weak reasons.