microsoft / knossos-ksc

Compiler with automatic differentiation

Reducing size of elim BOG #689

Open toelli-msft opened 3 years ago

toelli-msft commented 3 years ago

We have a prim elim :: A -> (). Its SUF/BOG-AD reverse pass is [sufrevpass elim] :: ((), BOG) -> dA, where the return value is the zero of type dA. We are free to choose BOG as long as the zero of type dA can be constructed from it. The easiest choice for BOG is mkTangentZero of the argument that elim was supplied with, that is

[suffwdpass elim] a = ((), mkTangentZero a)
[sufrevpass elim] ((), bog) = bog

But when A = Vec Float this BOG is much too big! If a is a Vec Float of size n then all we need to recover mkTangentZero a is n. Much better! Can we use this strategy to make all BOGs for elim small? Unfortunately not! If r : Vec (Vec Float) were ragged then the BOG would need to at least be a Vec Integer describing the length of each nested vector.
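A minimal Python sketch of the BOG choices discussed above, with a Vec Float modelled as a list of floats and a Vec (Vec Float) as a list of lists. The function names mirror the notation in this issue but are otherwise hypothetical; this is an illustration, not ksc code.

```python
# Illustrative sketch only: Vec Float is modelled as a Python list of floats.

def mk_tangent_zero(a):
    """Zero tangent with the same shape as a (nested lists of floats)."""
    if isinstance(a, list):
        return [mk_tangent_zero(x) for x in a]
    return 0.0

# Easy choice: the BOG is the full zero tangent of the argument.
def suffwdpass_elim_full(a):
    return (), mk_tangent_zero(a)

def sufrevpass_elim_full(d_unit, bog):
    return bog

# Compact choice for Vec Float: the BOG is just the length n.
def suffwdpass_elim_vec_float(a):
    return (), len(a)

def sufrevpass_elim_vec_float(d_unit, n):
    return [0.0] * n

# Ragged Vec (Vec Float): one length per inner vector is needed.
def suffwdpass_elim_ragged(a):
    return (), [len(row) for row in a]

def sufrevpass_elim_ragged(d_unit, lengths):
    return [[0.0] * n for n in lengths]
```

In the ragged case the BOG is a list of integers, which is still smaller than the nested vector of zeros it replaces.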

Still, we can take advantage of this strategy in some special cases. There is still a difficulty though. If we choose Integer for the BOG of elim on Vec Float then what do we choose for the BOG of elim on Vec (Float, Float) etc.? Given that [sufrevpass elim] :: ((), BOG) -> dA we need to be able to determine dA from BOG. Some possibilities:

  1. Come up with an ad hoc encoding, e.g. the BOG of Vec Float could be ((), Integer) and the BOG of Vec (Float, Float) could be (((), ()), Integer). I don't hold out much hope of this possibility being feasible to work with in the general case.

  2. Use Dummy for tagging values. However, Dummy is not erased at run time so it will incur a performance penalty (perhaps slight). (N.B. we can't erase Dummy at the moment because it is used to fake up sum types from tuples. If we added real sum types then Dummy could be erased and used for this purpose.)

  3. Introduce a new type, like Dummy, but that is actually erased at run time.

  4. Introduce structuring to the names of prims. elim : A -> () could become [elim A] : A -> (), [sufrevpass elim] : ((), BOG) -> dA would become [sufrevpass [elim A]] : ((), BOG) -> dA. The problem would then vanish. (Perhaps this is the natural solution. Solving this problem for userfuns was why it was essential to introduce structured names.)
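To illustrate why option 4 would remove the ambiguity, here is a small Python sketch in which reverse passes are looked up by a structured name (prim name plus argument type). The registry, the type-name strings and the helper names are all hypothetical; they only stand in for ksc's structured names.

```python
# Hypothetical registry keyed by a structured name: (prim, argument type).
# With the type in the name, the same Integer BOG can serve several types.

def rev_elim_vec_float(d_unit, n):
    # dA = Vec Float
    return [0.0] * n

def rev_elim_vec_pair(d_unit, n):
    # dA = Vec (Float, Float)
    return [(0.0, 0.0)] * n

SUFREVPASS = {
    ("elim", "Vec Float"): rev_elim_vec_float,
    ("elim", "Vec (Float, Float)"): rev_elim_vec_pair,
}

def sufrevpass(prim, arg_type, d_result, bog):
    """Dispatch on the structured name rather than on the BOG's contents."""
    return SUFREVPASS[(prim, arg_type)](d_result, bog)
```

Both entries receive the same BOG (an integer length); only the structured name tells them which zero to build.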

Thoughts @awf?

awf commented 3 years ago

For 1/2/3: I think you mean something like this.

| Instead of returning | return |
| --- | --- |
| constVec n val | (n, val) |
| build n (i -> constVec i (f i)) | build n (i -> (i, f i)) |

And the concern is that the RHS type might not uniquely define the LHS activity? Essentially we need to deal with the cases in getZero, and the encoding needs to be sufficient to distinguish those cases. So it's a finite number of cases, we should check.

Can we use this strategy to make all BOGs for elim small? Unfortunately not! If r : Vec (Vec Float) were ragged then the BOG would need to at least be a Vec Integer describing the length of each nested vector.

Yes, this is fine -- this is what shape has to do anyway. It's still smaller than the Vec (Vec Float) it replaces.

2/3. We can probably eliminate the perf penalty for dummy in a lot of cases.

  4. Yes, I can't see a huge downside here -- elims can probably be aggressively inlined after SUF is used for AD, or could be elided anyway in a conversion "out of SUF" before emitting a KSO. If we believe SUF can also be used for object lifetime management, it can be re-run after the inlining to see what's left.

awf commented 3 years ago

So, yes: 4 seems worth looking at. It may be that it ends up having to be "some" prims only, in which case we may as well start with a list of prims that get monomorphised, where that list contains only elim to begin with.

simonpj commented 3 years ago

I am still wondering whether, instead of thinking about cunning ways to represent zero, we could explore the functional representation, in which zero is the identity function and addition is function composition. If successful that would render all these questions moot.

toelli-msft commented 3 years ago

To expand a bit, for every type T we have to choose

  • a type ElimBOG{T}
  • a function elimBog : T -> ElimBOG{T}
  • a function bogZero : ElimBOG{T} -> dT

such that bogZero (elimBog t) is the zero of the same shape as t. Furthermore, unless we take the step of adding a structured type to the prim elim then we have the additional requirement

There are two things we'd like to achieve with this choice

For example, when T = Vec Float we can make the choice

but this wastes space. A choice with smaller BOG yet that runs in the same time is

The question at hand is which choice we should make for ElimBOG, elimBog and bogZero.
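One possible reading of the two Vec Float choices, sketched in Python as (elimBog, bogZero) pairs together with a check of the law that bogZero (elimBog t) is the zero shaped like t. The pairs are reconstructed from the earlier discussion in this thread rather than taken from this comment, and all names are illustrative.

```python
# Sketch: two candidate (elimBog, bogZero) pairs for T = Vec Float,
# with Vec Float modelled as a list of floats.

def zero_like(t):
    return [0.0] * len(t)

# Choice 1: ElimBOG{Vec Float} = Vec Float (stores a full vector of zeros).
full = {
    "elim_bog": lambda t: [0.0] * len(t),
    "bog_zero": lambda bog: bog,
}

# Choice 2: ElimBOG{Vec Float} = Integer (stores only the length).
compact = {
    "elim_bog": lambda t: len(t),
    "bog_zero": lambda n: [0.0] * n,
}

def satisfies_law(choice, samples):
    """Check bogZero (elimBog t) == zero of the same shape as t."""
    return all(
        choice["bog_zero"](choice["elim_bog"](t)) == zero_like(t)
        for t in samples
    )
```

Both pairs satisfy the law; they differ only in how much the BOG has to store between the forward and reverse passes.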

@awf's previous comment

this is what shape has to do anyway

makes me think that this question is exactly a question of shape (except perhaps the additional requirement around coincidence). I will look into the shapes stuff to see if it already solves this problem.

toelli-msft commented 3 years ago

@simonpj suggests

we could explore the functional representation, in which zero is the identity function and addition is function composition

I shall write up more on that when I get a moment.

toelli-msft commented 3 years ago

I think that behind @simonpj's suggestion lies the idea that, regardless of how we store them in the bog, we shouldn't be creating and adding large arrays of zeroes to things in any case. There are two broad approaches we could take:

  1. Thread large arrays and use update-in-place to avoid the need to create zeros
  2. Use a representation of the tangent space which (effectively) allows tagged zeroes, as well as possibly other goodies.

@simonpj and I discussed 2. We concluded that using a functional representation T -> T or dT -> dT for the tangent space doesn't help. To see why, consider

constVec : (Int, Float) -> Vec Float

Its reverse pass under a functional representation would be something like

sufrevpass$constVec : (Vec Float -> Vec Float, BOG) -> ((Int, Float) -> (Int, Float))

How can we make use of the Vec Float -> Vec Float that is passed in when we have no Vec Float to apply it to? Well, there is one special value of Vec Float available to us, the zero vector. But now it seems we've just moved the problem of creating zero vectors elsewhere.
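The difficulty can be seen in a small Python sketch of the functional representation, where a tangent is a function from accumulator to accumulator, zero is the identity and addition is composition. The helper names are hypothetical; the point is only that reading a value back out still requires a zero vector of the right length.

```python
# Sketch of the functional tangent representation for Vec Float.

def d_zero(acc):
    """Zero tangent: the identity function."""
    return acc

def d_add(f, g):
    """Tangent addition: function composition."""
    return lambda acc: f(g(acc))

def d_inc_at(i, x):
    """Tangent that adds x at index i (returns a new list)."""
    def apply(acc):
        out = list(acc)
        out[i] += x
        return out
    return apply

def materialise(tangent, n):
    # To observe the tangent we must apply it to something, and the only
    # candidate is a freshly built zero vector of length n.
    return tangent([0.0] * n)
```

Here materialise still needs n and still allocates the zeros, so the cost has moved rather than disappeared.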

There is an alternative choice of tangent space. We could come up with a sum type representing all operations that can be applied to a tangent space, for example

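{-# LANGUAGE GADTs #-}

-- Stand-in for KSC's Vec so that this compiles on its own.
newtype Vec a = Vec [a]

-- Operations that can be applied to a tangent space.
data Delta a where
    Add :: Delta a -> Delta a -> Delta a
    Zero :: Delta a
    Leaf :: Float -> Delta Float
    IncAt :: Int -> Delta a -> Delta (Vec a)
    DFst :: Delta a -> Delta (a,b)
    DSnd :: Delta b -> Delta (a,b)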

But this poses more problems:

  1. KSC doesn't have sum types, let alone GADTs
  2. This is effectively a "sparse" tangent space representation. How do we stop it getting too large? When a large number of IncAts (or Leafs) have accumulated we ideally want to accumulate them into a dense representation somehow.
  3. Are we confident that this representation can be differentiated a second time and, moreover, efficiently?
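As a rough illustration of problem 2, the Delta constructors can be modelled in Python (which has no GADTs, so the type indices are lost) together with an evaluator that folds a Delta into a dense value. The class names follow the constructors above; eval_delta and its acc argument are assumptions made for this sketch.

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Zero:
    pass

@dataclass
class Leaf:
    value: float

@dataclass
class Add:
    left: Any
    right: Any

@dataclass
class IncAt:
    index: int
    delta: Any

@dataclass
class DFst:
    delta: Any

@dataclass
class DSnd:
    delta: Any

def eval_delta(delta, acc):
    """Return acc + delta as a dense value; acc supplies the shape."""
    if isinstance(delta, Zero):
        return acc
    if isinstance(delta, Leaf):
        return acc + delta.value
    if isinstance(delta, Add):
        return eval_delta(delta.right, eval_delta(delta.left, acc))
    if isinstance(delta, IncAt):
        out = list(acc)
        out[delta.index] = eval_delta(delta.delta, out[delta.index])
        return out
    if isinstance(delta, DFst):
        return (eval_delta(delta.delta, acc[0]), acc[1])
    if isinstance(delta, DSnd):
        return (acc[0], eval_delta(delta.delta, acc[1]))
    raise TypeError(f"not a Delta: {delta!r}")
```

The Delta tree grows with every Add, and collapsing it still needs a dense accumulator of the right shape, which is where the zero comes back in.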

dcrc2 commented 3 years ago

To expand a bit, for every type T we have to choose

  • a type ElimBOG{T}
  • a function elimBog : T -> ElimBOG{T}
  • a function bogZero : ElimBOG{T} -> dT

such that bogZero (elimBog t) is the zero of the same shape as t.

As you've hinted above, I think what you've presented here is precisely the definition of ShapeType{T}, with elimBog being the existing primitive function shape. (bogZero is a function which doesn't currently exist in ksc, though it has an implementation in knossos.h for use in buildFromSparse). Indeed, if shapes are to be useful, it seems important that ShapeType{T} is the most efficient representation possible, so if ShapeType{T} seems suboptimal, we should really fix that by changing ShapeType{T}, rather than by introducing a separate ElimBOG{T}.

Currently ShapeType{Vec Float} is Vec (Tuple), and this may not be optimal. As we were discussing yesterday, we could optimize this in the backend, by having a special representation for Vec (Tuple) which was as efficient as an Integer. (In Cgen this could be a template specialization for tensor<N, tuple<>>.) But since we'd have to do this optimization in every backend (and RLO would have to know about it in order to calculate the cost correctly), it might be better to change the ks definition instead.

(But I also agree that it would be a lot better to avoid creating zeros at all, if that's possible.)

awf commented 3 years ago

It feels fine to have the shape rules include a special case for Vec Scalar:

ShapeType{Int} = Tuple []
ShapeType{Vec Float} = Int
ShapeType{Vec Int} = Int
ShapeType{Tuple [T1 .. Tn]} = Tuple [ShapeType{T1} .. ShapeType{Tn}]
ShapeType{Vec T} = Vec ShapeType{T}

We don't require that T be inferrable from ShapeType{T}, as we can always arrange to have T nearby.

If we haven't already, let's make a doc somewhere describing ShapeType{T} -- either in a Haskell Note, or perhaps Shapes.md beside Shapes.hs, or in doc/sphinx/Transforms/Shape.rst.

And in that doc, we should write the invariants we want. I believe one invariant might be

S{T} = ShapeType{T} is such that from a value of S{T}, and T, a zero object can be constructed.

Or

mkZero@T (shape z) == z when z is a zero of type T.
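The rules and the second invariant can be sketched in Python. Types are encoded as small tagged values because Python has no ShapeType{T}; the encoding, the treatment of Float as having shape Tuple [], and the use of 0 as the zero of Int are assumptions of this sketch rather than statements about ksc.

```python
# Type encoding (hypothetical): "Float", "Int", ("Vec", T), ("Tuple", [T1, ..., Tn]).
ZERO = {"Float": 0.0, "Int": 0}

def is_scalar(ty):
    return ty in ("Float", "Int")

def shape(ty, v):
    """Shape of v : ty, following the rules with the Vec Scalar special case."""
    if is_scalar(ty):
        return ()                              # ShapeType{Scalar} = Tuple []
    tag, arg = ty
    if tag == "Vec":
        if is_scalar(arg):
            return len(v)                      # ShapeType{Vec Scalar} = Int
        return [shape(arg, x) for x in v]      # ShapeType{Vec T} = Vec ShapeType{T}
    if tag == "Tuple":
        return tuple(shape(t, x) for t, x in zip(arg, v))
    raise ValueError(f"unknown type: {ty!r}")

def mk_zero(ty, s):
    """Build the zero of type ty from a shape s; ty is passed in explicitly."""
    if is_scalar(ty):
        return ZERO[ty]
    tag, arg = ty
    if tag == "Vec":
        if is_scalar(arg):
            return [ZERO[arg]] * s
        return [mk_zero(arg, x) for x in s]
    if tag == "Tuple":
        return tuple(mk_zero(t, x) for t, x in zip(arg, s))
    raise ValueError(f"unknown type: {ty!r}")
```

Note that mk_zero takes the type as an argument: an Int shape alone cannot distinguish Vec Float from Vec Int, which matches the remark that T need not be inferable from ShapeType{T}.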

I would also really like a doc for the other transformations that is not AD.pdf

dcrc2 commented 3 years ago

On prims and structured names: I think we should be regarding elim as an implementation detail of AD, so that the answer to this question shouldn't actually matter outside of ksc.

That is, elim can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with getZero or [sufrevpass elim]: these can always be rewritten in terms of literal zeros and constVec.)

This would be similar to how we deal with constructs like [rev constVec] at the moment: this can arise in unoptimized AD, but is always rewritten to a form which uses base primitives. Cgen does not know how to generate code for any derived functions whose base is a primitive. I think this is a good thing, and we should be minimizing the number of primitives that backends need to know about.

If elim doesn't appear in the .kso then I think the simplest answer then becomes:

This would mean that the .kso is unchanged from what we have at the moment, while still solving the disambiguation problem.

Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.

awf commented 3 years ago

Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.

Exactly -- I don't see this as a real problem, and it would maybe be better than having more special cases around prims.

Let's recall:

  1. Some prims (eq, ts_add, build, fold, map) are just function templates -- we would expect to modify KSC to emit edefs for each used instance. Those edefs would obtain structured names in the normal way (either literally emitted, or trivially assembled at kso parse time).
  2. A few prims are very special, e.g. get$1 etc, and those can be special-cased.

dcrc2 commented 3 years ago

Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.

I was thinking of things like ts_add when I wrote this; but I'm not so sure it makes sense for the "very special" primitives like build and get$1. So it feels like we're always going to end up with two kinds of prims:

And the large majority of prims could reasonably belong to either category. But I'd have a preference for saying that anything that the backends need to understand should be in the first category; any primitives in the second category should be rewritten to other forms by ksc and not appear in the .kso. Then the backends only have one kind of prim to worry about.

toelli-msft commented 3 years ago

That is, elim can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with getZero or [sufrevpass elim]: these can always be rewritten in terms of literal zeros and constVec.)

Are you proposing a subdivision of prims into those that can appear in a .kso and those that can't? How do we make that distinction? In some sense ts_scale is an implementation detail of AD too. How should we determine whether it is allowed to appear in a .kso?

toelli-msft commented 3 years ago

Seems like we're not the only ones having problems with elim! 😄

https://github.com/microsoft/knossos-ksc/pull/645/files#diff-a533bc14555235deb5212517ca30ee150d8d71f268b2c31a0b541f34f65ccfaaR10-R20

awf commented 3 years ago

Seems like we're not the only ones having problems with elim! 😄

https://github.com/microsoft/knossos-ksc/pull/645/files#diff-a533bc14555235deb5212517ca30ee150d8d71f268b2c31a0b541f34f65ccfaaR10-R20

Sorry, I just broke your link while making a fix that it inspired, and now I can't see how to link to a specific line in a PR...

awf commented 3 years ago

And yes! (Although I do note that linear map AD needs neither elim nor the PyTorch hack...)

toelli-msft commented 3 years ago

And yes! (Although I do note that linear map AD needs neither elim nor the PyTorch hack...)

LM AD does have shaped zeros though: https://github.com/microsoft/knossos-ksc/blob/d9d10db1e2e7cbe2ff079d0f0aebf6d27f128c57/src/ksc/AD.hs#L93

awf commented 3 years ago

Agreed, but there's no particularly special case there. And of course LM and SUF both are infinitely better than the PyTorch situation.

toelli-msft commented 3 years ago

there's no particularly special case there

Indeed not. The LM AD zeroes are storing much more than they need to!

awf commented 3 years ago

Agreed, but they should just use shape?

awf commented 3 years ago

And they are optimised away in any case. We don't see lmZero in emitted code? Of course we do see constVec zeros, but we haven't done all the work to rewrite them away.

toelli-msft commented 3 years ago

they should just use shape?

Yes probably. I think the problem is the same and the solution is the same.

And they are optimised away in any case. We don't see lmZero in emitted code? Of course we do see constVec zeros, but we haven't done all the work to rewrite them away.

Right, lmZero is rewritten to constVec, or a tuple of zeros, or whatever. The zeros themselves are not optimised away in general.

dcrc2 commented 3 years ago

That is, elim can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with getZero or [sufrevpass elim]: these can always be rewritten in terms of literal zeros and constVec.)

Are you proposing a subdivision of prims into those that can appear in a .kso and those that can't? How do we make that distinction? In some sense ts_scale is an implementation detail of AD too. How should we determine whether it is allowed to appear in a .kso?

I do think we should try to minimize the number of prims that can appear in a .kso. And if we decide to make a distinction between prims that can appear in a .kso and those that can't, then that makes it easy to add new "private" prims in ksc without breaking consumers of .kso files.

I'd suggest that the test should be whether it is possible to rewrite the operation in terms of existing prims (or ordinary monomorphic functions) without sacrificing performance.

(Apologies for the delayed reply.)

toelli-msft commented 3 years ago

Interesting. elim is a "SUF KS prim" since it can't be expressed in SUF KS in terms of another prim, but it isn't a "KS prim" since in KS it can just be dropped entirely.