Mikolaj opened this issue 1 year ago
It's a bit stranger than I thought, but this is the best I've come up with so far.
Source language: our `Delta` expressions.

Target language: terms in the target language described below, with a distinguished variable `c` in the environment. Target language terms `t` are typed with a judgement `Γ | c : σ ||- t : τ`. (Using `||-` just to visually distinguish it from the standard non-linear judgement `|-`, which is also used in some of the typing rules below.) These terms define an algebraically linear function `σ -> τ` with a non-linear environment `Γ`.

Given a source term `Γ |- t : τ`, transposition produces in the target language: `Γ | c : τ ||- transp[t] : DMap`. The `DMap`, `Variable`, `Array` and `Shape` types don't have further type indices in the typing rules below because I feared madness otherwise. There's probably a way to include those and make this type-safe.
# LINEARITY

```
------------------- (zero)
Γ | c : τ ||- 0 : σ

Γ | c : τ ||- s : σ    Γ | c : τ ||- t : σ
------------------------------------------ (plus)
Γ | c : τ ||- s + t : σ

Γ |- r : Real    Γ | c : τ ||- t : σ
------------------------------------ (scale)
Γ | c : τ ||- r * t : σ
```
# LINEAR BINDING

```
Γ | c : τ ||- s : σ₁    Γ | c : σ₁ ||- t : σ₂
--------------------------------------------- (let)
Γ | c : τ ||- let c = s in t : σ₂

------------------- (var)
Γ | c : τ ||- c : τ
```
# COTANGENT MAPS

```
n : Variable
------------------------------- (onehot map)
Γ | c : τ ||- Onehot n c : DMap

Γ | c : τ ||- t : DMap    n : Variable
-------------------------------------- (map delete)
Γ | c : τ ||- Delete n t : DMap

Γ | c : τ ||- t : DMap    n : Variable
-------------------------------------- (map lookup)
Γ | c : τ ||- Lookup n t : DMap
```
# ARRAY OPERATIONS

```
Γ |- sh : Shape    Γ |- f : Int -> Int
Γ | c : τ ||- t : Array
-------------------------------------- (scatter)
Γ | c : τ ||- Scatter sh f t : Array
```
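For concreteness, here's a rough Haskell sketch of a deep embedding of this target language. The type and constructor names (`Target`, `LetC`, `C`, ...) are mine, invented for illustration, and the un-indexed `DMap`/`Array` placeholders mirror the rules above; nothing here is existing horde-ad code.

```haskell
-- Hypothetical deep embedding of the target language above (a sketch only).
-- Types are not indexed, mirroring the untyped DMap/Variable/Array/Shape
-- of the rules; making this type-safe is the open question mentioned above.

type Variable = Int
type Shape    = [Int]

data Target
  = Zero                               -- (zero)
  | Plus Target Target                 -- (plus)
  | Scale Double Target                -- (scale); the non-linear scalar from Γ, simplified to a literal
  | LetC Target Target                 -- (let): rebinds the distinguished variable c
  | C                                  -- (var): the distinguished linear variable c
  | Onehot Variable Target             -- (onehot map)
  | Delete Variable Target             -- (map delete)
  | Lookup Variable Target             -- (map lookup)
  | Scatter Shape (Int -> Int) Target  -- (scatter)
```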
Example transposition rules (rewritten from `eval` in the original POPL paper, and gather from `simplified/HordeAd/Core/Delta.hs:buildFinMaps:evalR`):
```
T[Zero] = 0
T[Scale y u] = let c = y * c in T[u]
T[Add u₁ u₂] = T[u₁] + T[u₂]
T[Var n] = Onehot n c
T[Let n u₁ u₂] = let c = T[u₂] in Delete n c + (let c = Lookup n c in T[u₁])
  -- Note that the first c binding above is of type DMap.
T[GatherZ sh u f sha] = let c = Scatter sha c f in T[u]
  -- I would rather write 'GatherZ sh f u sha' (or even 'GatherZ sh f u') and
  -- 'Scatter sha f c', but this is horde-ad argument order. :D
```
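The `T[_]` rules can be read off as a function from `Delta` to the hypothetical `Target` type sketched above. The `Delta` fragment below is cut down and renamed for illustration only; it is not the actual horde-ad `Delta` type, and `Scatter` here uses the `Scatter sha f c` argument order of the sketch AST rather than the horde-ad order.

```haskell
-- A cut-down Delta fragment, for illustration only.
data Delta
  = DZero
  | DScale Double Delta
  | DAdd Delta Delta
  | DVar Variable
  | DLet Variable Delta Delta
  | DGatherZ Shape Delta (Int -> Int) Shape

-- Transposition, following the T[_] rules above.
transp :: Delta -> Target
transp DZero        = Zero
transp (DScale y u) = LetC (Scale y C) (transp u)      -- let c = y * c in T[u]
transp (DAdd u1 u2) = Plus (transp u1) (transp u2)
transp (DVar n)     = Onehot n C
transp (DLet n u1 u2) =                                -- let c = T[u₂] in
  LetC (transp u2)                                     --   Delete n c
       (Plus (Delete n C)                              --   + (let c = Lookup n c
             (LetC (Lookup n C) (transp u1)))          --      in T[u₁])
transp (DGatherZ _sh u f sha) =
  LetC (Scatter sha f C) (transp u)                    -- let c = Scatter sha f c in T[u]
```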
While the above is a fine linear language, I think, the point of this was to Cayley-transform `DMap`, so we'd need to specialise `0` and `+` to only work on `DMap` values, i.e.

```
---------------------- (zero)
Γ | c : τ ||- 0 : DMap

Γ | c : τ ||- s : DMap    Γ | c : τ ||- t : DMap
------------------------------------------------ (plus)
Γ | c : τ ||- s + t : DMap
```
after which I think we can without issues implement `DMap` actually as `Endo DMap ~= DMap -> DMap` with the standard monoid operations `id` and `(.)`. We'd have to double-check that `Delete` and `Lookup` continue to be implementable in the presence of `let` after `DMap` is Cayley-transformed. I think that works out, but I'm not 100% sure.
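To make the Cayley transform concrete, here is a minimal sketch, assuming an `IntMap Double` stand-in for `DMap`. It only covers `0`, `+` and `Onehot`; `Delete` and `Lookup` are exactly the part whose implementability is the open question above, so they are left out.

```haskell
import Data.Monoid (Endo (..))
import qualified Data.IntMap.Strict as IM

-- Stand-in for DMap: cotangent contributions keyed by variable id.
type DMap = IM.IntMap Double

-- Cayley-transformed DMap values: a map m is represented by the function
-- that adds m to its argument, so 0 becomes id and + becomes composition.
type DMapC = Endo DMap

zeroC :: DMapC
zeroC = mempty                      -- Endo id

plusC :: DMapC -> DMapC -> DMapC
plusC = (<>)                        -- composition of the underlying functions

-- Onehot n c adds the cotangent c to the entry for variable n.
onehotC :: Int -> Double -> DMapC
onehotC n c = Endo (IM.insertWith (+) n c)

-- Recover an ordinary DMap by applying the accumulated function to the
-- empty map.
runDMapC :: DMapC -> DMap
runDMapC f = appEndo f IM.empty
```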
Hah, yes, this is unexpected.

Re `GatherZ sh f u`, IIRC the extra shape is needed to compute the forward derivative from a delta expression without traversing the delta expression to reconstruct the shape. The `u` in front of `f` may be related to simplification, vectorization, etc., recursing over `u` rather than over `f`, or to indexing taking first the term and only then the index. These are weak reasons.

Here's Tom's idea from https://github.com/Mikolaj/horde-ad/issues/95#issuecomment-1503013877:
> As far as I understand so far, we'd need more `Delta` constructors and we'd transpose by rewriting the `Delta` expressions. Afterwards, to get rid of the `Delta` constructors, we'd interpret them straightforwardly as linear transformation syntax, which is implemented (and probably bit-rotted) in `buildDerivative`, which computes forward derivatives on the basis of `Delta` collected in the forward pass. In the pipeline that produces gradient `Ast`, we'd run `buildDerivative` with `Ast` codomain.
>
> I wonder how much the extended `Delta` would start resembling `Ast`. If we switched from `Delta` to `Ast`, we'd end up with `Ast` nested inside `Ast`, which would be fine in this case. But perhaps `Delta` has some advantage over `Ast`? I guess it's linear, just as the linear sublanguage in the YOLO paper. I wonder if it's easy to extend `Delta` to be enough for this term rewriting, but still syntactically guaranteeing linearity (assuming it does ATM, after all the tensor extensions).
>
> A modest version of this rewrite is to evaluate `Delta` to combinators of type `DeltaMap -> DeltaMap`, but not reify the combinators as constructors of `Delta`. Then, this can be presented as a rewrite, but implemented with immediate evaluation of the rewritten terms. Relevant parts of `buildDerivative` are then manually inlined.
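As a rough illustration of that modest version, here is a sketch (not horde-ad code; `DeltaMap` is again an `IntMap Double` stand-in) that evaluates the scalar part of the small `Delta` fragment from above directly to `DeltaMap -> DeltaMap` combinators instead of reifying target-language terms:

```haskell
import qualified Data.IntMap.Strict as IM

type DeltaMap = IM.IntMap Double

-- evalD c d accumulates the cotangent contributions of d, scaled by the
-- incoming cotangent c, into a DeltaMap; composition of these functions
-- plays the role of + after the Cayley transform.
evalD :: Double -> Delta -> (DeltaMap -> DeltaMap)
evalD _ DZero          = id
evalD c (DScale y u)   = evalD (y * c) u
evalD c (DAdd u1 u2)   = evalD c u1 . evalD c u2
evalD c (DVar n)       = IM.insertWith (+) n c
evalD c (DLet n u1 u2) = \m ->
  -- Mirrors T[Let]: accumulate u2's contributions first, then push the
  -- cotangent collected for n through u1. Assumes variable ids are unique,
  -- so n cannot already occur in the incoming map m.
  let m' = evalD c u2 m
      cn = IM.findWithDefault 0 n m'
  in evalD cn u1 (IM.delete n m')
evalD _ (DGatherZ {})  = id  -- array cases omitted in this scalar sketch
```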