microsoft / knossos-ksc

Compiler with automatic differentiation

Structured prims #824

Closed. toelli-msft closed this 3 years ago

toelli-msft commented 3 years ago

The immediate problem that we are trying to solve is https://github.com/microsoft/knossos-ksc/pull/766#issuecomment-846851817: prims must be annotated with their base argument type so that the return type of their SUF reverse pass can be deduced. Userfuns are already annotated with their base type. Removing this distinction between userfuns and prims actually leads to a modest simplification of the codebase.
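A simplified model shows why the base argument type is needed. Assume, as in awf's `elim` example further down, that the SUF reverse pass returns the tangent of the base function's argument type. Its return type is then a function of that argument type alone, and it cannot be deduced when the argument type is missing. The `Type` constructors, `tangentType`, and `sufRevPassReturnType` below are hypothetical illustrations, not ksc's definitions. Treating `Integer` and `String` as having the empty-tuple tangent `(Tuple)` is also an assumption.

```haskell
-- Hypothetical, simplified model of ks types (not ksc's own Type).
data Type
  = TInteger
  | TFloat
  | TString
  | TVec Type
  | TTensor Int Type   -- rank and element type
  | TTuple [Type]
  deriving (Eq, Show)

-- Assumed tangent types: discrete types get the empty tuple (Tuple),
-- containers and tuples take the tangent of their components.
tangentType :: Type -> Type
tangentType TInteger      = TTuple []
tangentType TFloat        = TFloat
tangentType TString       = TTuple []
tangentType (TVec t)      = TVec (tangentType t)
tangentType (TTensor n t) = TTensor n (tangentType t)
tangentType (TTuple ts)   = TTuple (map tangentType ts)

-- The reverse pass maps a result tangent back to an argument tangent,
-- so its return type is the tangent of the base argument type.
-- Without an argument type (Nothing) there is nothing to compute from.
sufRevPassReturnType :: Maybe Type -> Maybe Type
sufRevPassReturnType = fmap tangentType

main :: IO ()
main = do
  let argTy = TVec (TTuple [TString, TTensor 2 TFloat])
  print (sufRevPassReturnType (Just argTy))  -- annotated prim
  print (sufRevPassReturnType Nothing)       -- unannotated prim
```

For `Vec (Tuple String (Tensor 2 Float))` this gives `Vec (Tuple (Tuple) (Tensor 2 Float))`, the same return type awf writes for `[sufrevpass [elim ...]]`.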

The changes all arise from a change in Lang.hs to the data types we use to represent function identifiers. This PR makes the change below, and everything else follows from it:

data DerivedFun funname p = Fun Derivations (BaseFunId funname p)

-- DerivedFun has just two instantiations
--
-- Fun p: these can be called, and hence appear in the function field
-- of a Call (in the guise of a TFun)
--
-- UserFun p: these can be def'd/edef'd, and hence form the domain of
-- the GblSymTab and CST and appear in the def_fun field of DefX.
type UserFun p = DerivedFun String p
type Fun     p = DerivedFun BaseName p

data BaseFunId name (p :: Phase) = BaseFunId name (BaseArgTy p)

data BaseName = BaseUserFunName String   -- BaseUserFuns have a Def
              | BasePrimFunName PrimFun  -- PrimFuns do not have a Def
              deriving (Eq, Ord, Show)

type family BaseArgTy p where
  BaseArgTy Parsed   = Maybe Type
  BaseArgTy OccAnald = Type
  BaseArgTy Typed    = Type

That is, the BaseArgTy (formerly called BaseUserFunArgTy) is part of the identity of the base function, whether that base function is a userfun or a prim. Treating the two uniformly also simplifies other parts of the code: there is less need for lenses.
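Below is a self-contained, compilable sketch of the data types above. It assumes minimal stand-ins for `Type`, `Derivations`, and `PrimFun`; the real ksc definitions are richer. Because the base argument type lives in `BaseFunId`, a single accessor (the hypothetical `baseArgType`) works for userfuns and prims alike.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies #-}

-- Hypothetical stand-ins for ksc's Type, Derivations and PrimFun.
data Type = TFloat | TInteger | TTuple [Type] | TVec Type
  deriving (Eq, Show)

data Derivation = SUFFwdPass | SUFRevPass
  deriving (Eq, Show)

type Derivations = [Derivation]   -- assumed: outermost derivation first

data PrimFun = PrimBuild | PrimElim
  deriving (Eq, Ord, Show)

data Phase = Parsed | OccAnald | Typed

-- The argument type is optional only before type checking.
type family BaseArgTy (p :: Phase) where
  BaseArgTy 'Parsed   = Maybe Type
  BaseArgTy 'OccAnald = Type
  BaseArgTy 'Typed    = Type

data BaseFunId name (p :: Phase) = BaseFunId name (BaseArgTy p)

data DerivedFun funname (p :: Phase) = Fun Derivations (BaseFunId funname p)

data BaseName = BaseUserFunName String   -- has a Def
              | BasePrimFunName PrimFun  -- has no Def
              deriving (Eq, Ord, Show)

type UserFun p = DerivedFun String p
type Fun     p = DerivedFun BaseName p

-- One accessor for both userfuns and prims: no per-case lenses needed.
baseArgType :: Fun 'Typed -> Type
baseArgType (Fun _ (BaseFunId _ ty)) = ty

isPrim :: Fun p -> Bool
isPrim (Fun _ (BaseFunId (BasePrimFunName _) _)) = True
isPrim _                                         = False

main :: IO ()
main = do
  let userF = Fun [] (BaseFunId (BaseUserFunName "f") TFloat) :: Fun 'Typed
      primF = Fun [SUFRevPass] (BaseFunId (BasePrimFunName PrimElim) (TVec TFloat)) :: Fun 'Typed
  print (map baseArgType [userF, primF])
  print (map isPrim [userF, primF])
```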

Remaining questions

What next?

This will allow https://github.com/microsoft/knossos-ksc/pull/766 to be merged.

simonpj commented 3 years ago

An alternative approach: kill PrimFun altogether, and generate an edef for each instance of our current primitives. That would solve the problem that this PR addresses, and remove 400 lines of code from Prim.hs (plus more elsewhere). The cost would be that of generating a raft of edefs.

When my polymorphism patch arrives, we could perhaps kill off some of those edefs again. Maybe that would raise the priority of landing the polymorphism patch.
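Generating such edefs could be mechanical. The sketch below prints `build` edefs for a given element type in the same shape as the ones awf lists next: `Vec` with an `Integer` index for rank 1, and `Tensor r` with a tuple of `Integer` indices otherwise. The `Type` model and the `renderBare`, `render`, and `buildEdef` helpers are hypothetical, written for this illustration rather than taken from ksc.

```haskell
-- Hypothetical, simplified model of ks types (not ksc's own Type).
data Type
  = TInteger
  | TFloat
  | TString
  | TVec Type
  | TTensor Int Type
  | TTuple [Type]

-- Render without outer parentheses, as written after a binder's colon.
renderBare :: Type -> String
renderBare TInteger      = "Integer"
renderBare TFloat        = "Float"
renderBare TString       = "String"
renderBare (TVec t)      = "Vec " ++ render t
renderBare (TTensor n t) = "Tensor " ++ show n ++ " " ++ render t
renderBare (TTuple ts)   = unwords ("Tuple" : map render ts)

-- Render in argument position: compound types are parenthesised.
render :: Type -> String
render t
  | isAtomic t = renderBare t
  | otherwise  = "(" ++ renderBare t ++ ")"
  where
    isAtomic TInteger = True
    isAtomic TFloat   = True
    isAtomic TString  = True
    isAtomic _        = False

-- One edef per (rank, element type) instance of build.
buildEdef :: Int -> Type -> String
buildEdef rank elt =
  "(edef build " ++ render result
    ++ " ((n : " ++ renderBare idx ++ ") (f : Lam " ++ render idx
    ++ " " ++ render elt ++ ")))"
  where
    idx    = if rank == 1 then TInteger else TTuple (replicate rank TInteger)
    result = if rank == 1 then TVec elt else TTensor rank elt

main :: IO ()
main = mapM_ (\rank -> putStrLn (buildEdef rank elt)) [1, 2]
  where
    elt = TTuple [TString, TTensor 2 TFloat]
```

For the element type `Tuple String (Tensor 2 Float)` this prints awf's two `build` edefs, one per line. Every further primitive and instance type adds more lines, which is the "raft of edefs" cost.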

awf commented 3 years ago

So, for build, if in scope:

(edef build (Vec (Tuple String (Tensor 2 Float)))
         ((n : Integer) (f : Lam Integer (Tuple String (Tensor 2 Float)))))
(edef build (Tensor 2 (Tuple String (Tensor 2 Float)))
         ((n : Tuple Integer Integer) (f : Lam (Tuple Integer Integer) (Tuple String (Tensor 2 Float)))))

(edef elim (Tuple) (Vec (Tuple String (Tensor 2 Float))))
(edef [suffwdpass [elim (Vec (Tuple String (Tensor 2 Float)))]]
         (Tuple (Tuple) BOG)
         (Vec (Tuple String (Tensor 2 Float))))
(edef [sufrevpass [elim (Vec (Tuple String (Tensor 2 Float)))]]
        (Vec (Tuple (Tuple) (Tensor 2 Float)))
        ((b : BOG) (dr : (Tuple))))

Others

(edef constVec (Vec (Tuple String Float)) ((n : Integer) (v : Tuple String Float)))

(edef lmAdd (LM S T) ((a : LM S T) (b : LM S T)))

awf commented 3 years ago

Super PR intro, thanks! (i.e. the opening comment https://github.com/microsoft/knossos-ksc/pull/824#issue-652310415)

dcrc2 commented 3 years ago

There's one thing I'm unsure about here, but I agree that the changes to the data structures are a good simplification, and that this will solve the problem found in #766.

What I'm not yet convinced by is writing prims as structured names in the output .kso. Firstly, this is a breaking change, so if we do go this way we'll at least have to open an issue for ksc-mlir to make it compatible. But I'd argue for writing prims as non-structured names in any case, because:

(To clarify: I'm not suggesting any changes to the data structures proposed by this PR. I'd just like to change the pretty-printing to render prims as non-structured names. If we have to print a derived function of a prim, then that would need to be a structured name in order to disambiguate, but derived-functions-of-prims should only ever be seen during intermediate stages of ksc compilation, never in the final output.)
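The printing rule proposed here (bare names for prims, structured names only for derived functions of prims) can be sketched as a small pretty-printer. The types are simplified and hypothetical, with the base argument type pre-rendered as a string. The `suffwdpass`/`sufrevpass` spellings and the `[name ArgType]` form follow awf's examples above. How base userfuns are printed is an assumption; this is not ksc's actual printer.

```haskell
data Derivation = SUFFwdPass | SUFRevPass

data BaseName = BaseUserFunName String | BasePrimFunName String

-- Hypothetical function id: derivations (outermost first), base name,
-- and the base argument type already rendered in ks syntax.
data Fun = Fun [Derivation] BaseName String

derivName :: Derivation -> String
derivName SUFFwdPass = "suffwdpass"
derivName SUFRevPass = "sufrevpass"

renderFun :: Fun -> String
-- A bare prim prints unstructured, so the .kso output is unchanged.
renderFun (Fun [] (BasePrimFunName p) _) = p
-- Everything else (userfuns, derived prims) gets a structured name.
renderFun (Fun ds base argTy) = foldr wrap (structured base) ds
  where
    structured (BaseUserFunName f) = "[" ++ f ++ " " ++ argTy ++ "]"
    structured (BasePrimFunName p) = "[" ++ p ++ " " ++ argTy ++ "]"
    wrap d inner = "[" ++ derivName d ++ " " ++ inner ++ "]"

main :: IO ()
main = mapM_ (putStrLn . renderFun)
  [ Fun [] (BasePrimFunName "elim") "(Vec Float)"            -- final output
  , Fun [SUFRevPass] (BasePrimFunName "elim") "(Vec Float)"  -- intermediate only
  ]
```

Only the second form needs the argument type to disambiguate, and it should never reach the final output.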

(Edited to add 4th point.)

toelli-msft commented 3 years ago

@dcrc2 I think what you're saying boils down to "This functionality doesn't require any changes to be visible outside ksc". I agree, so I have removed printing and parsing of prim structured names.

I think this is now ready, so I shall merge it imminently unless anyone shouts.

toelli-msft commented 3 years ago

Testing

awf commented 3 years ago

lgtm