toelli-msft opened 3 years ago
For 1/2/3: I think you mean something like this.
| Instead of returning | return |
|---|---|
| `constVec n val` | `(n, val)` |
| `build n (i -> constVec i (f i))` | `build n (i -> (i, f i))` |
And the concern is that the RHS type might not uniquely determine the LHS activity? Essentially we need to deal with the cases in `getZero`, and the encoding needs to be sufficient to distinguish those cases. So there is a finite number of cases; we should check them.
> Can we use this strategy to make all BOGs for `elim` small? Unfortunately not! If `r : Vec (Vec Float)` were ragged then the BOG would need to at least be a `Vec Integer` describing the length of each nested vector.

Yes, this is fine -- this is what `shape` has to do anyway. It's still smaller than the `Vec (Vec Float)` it replaces.
2/3. We can probably eliminate the perf penalty for `Dummy` in a lot of cases.
So, yes: 4 seems worth looking at. It may be that it ends up having to be "some" prims only, in which case we may as well start with a list of prims that get monomorphised, where that list contains only `elim` to begin with.
I am still wondering whether, instead of thinking about cunning ways to represent zero, we could explore the functional representation, in which zero is the identity function and addition is function composition. If successful that would render all these questions moot.
To expand a bit, for every type `T` we have to choose

- a type `ElimBOG{T}`
- a function `elimBog : T -> ElimBOG{T}`
- a function `bogZero : ElimBOG{T} -> dT`

such that `bogZero (elimBog t)` is the zero of the same shape as `t`. Furthermore, unless we take the step of adding a structured type to the prim `elim`, we have the additional requirement that `ElimBOG{T}` never coincides with `ElimBOG{U}` when `T` and `U` are different types.

There are two things we'd like to achieve with this choice:

- `elimBog t` being smaller is better
- `bogZero (elimBog t)` being quicker to run is better

For example, when `T = Vec Float` we can make the choice

```
ElimBOG{Vec Float} = Vec Float
elimBog t = constVec (size t) 0.0
bogZero bog = bog
```

but this wastes space. A choice with a smaller BOG that runs in the same time is

```
ElimBOG{Vec Float} = Integer
elimBog t = size t
bogZero bog = constVec bog 0.0
```

The question at hand is which choice we should make for `ElimBOG{T}`, `elimBog` and `bogZero`.
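To illustrate, here is a small Haskell sketch of the two choices above, with lists standing in for ksc's `Vec` and `elimBog1`/`bogZero1` etc. as purely illustrative names (not ksc identifiers):

```haskell
-- Illustrative model of the two ElimBOG{Vec Float} choices, using
-- Haskell lists in place of ksc's Vec. All names are hypothetical.
type Vec a = [a]

constVec :: Int -> a -> Vec a
constVec n v = replicate n v

size :: Vec a -> Int
size = length

-- Choice 1: ElimBOG{Vec Float} = Vec Float -- correct but wastes space
elimBog1 :: Vec Double -> Vec Double
elimBog1 t = constVec (size t) 0.0

bogZero1 :: Vec Double -> Vec Double
bogZero1 bog = bog

-- Choice 2: ElimBOG{Vec Float} = Integer -- a constant-size BOG
elimBog2 :: Vec Double -> Int
elimBog2 = size

bogZero2 :: Int -> Vec Double
bogZero2 n = constVec n 0.0

main :: IO ()
main = do
  let t = [1.0, 2.0, 3.0] :: Vec Double
  -- Both choices reconstruct the same zero, of the same shape as t:
  print (bogZero1 (elimBog1 t))  -- [0.0,0.0,0.0]
  print (bogZero2 (elimBog2 t))  -- [0.0,0.0,0.0]
```

Both satisfy the `bogZero (elimBog t)` requirement; the difference is only in the size of the stored BOG.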
@awf's previous comment

> this is what `shape` has to do anyway

makes me think that this question is exactly a question of shape (except perhaps the additional requirement around coincidence). I will look into the shapes stuff to see if it already solves this problem.
@simonpj suggests

> we could explore the functional representation, in which zero is the identity function and addition is function composition

I shall write up more on that when I get a moment.
I think that behind @simonpj's suggestion lies the idea that, regardless of how we store them in the bog, we shouldn't be creating and adding large arrays of zeroes to things in any case. There are two broad approaches we could take:
@simonpj and I discussed 2. We concluded that using a functional representation `T -> T` or `dT -> dT` for the tangent space doesn't help. To see why, consider

```
constVec : (Int, Float) -> Vec Float
```

Its reverse pass under a functional representation would be something like

```
sufrevpass$constVec : (Vec Float -> Vec Float, BOG) -> (Int, Float) -> (Int, Float)
```

How can we make use of the `Vec Float -> Vec Float` that is passed in when we have no `Vec Float` to apply it to? Well, there is one special value of `Vec Float` available to us: the zero vector. But now it seems we've just moved the problem of creating zero vectors elsewhere.
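To make the dead end concrete, here is a minimal Haskell sketch (all names are my own, not ksc's) of the functional tangent representation; note how a hypothetical reverse pass for `constVec` still ends up conjuring a zero vector:

```haskell
-- Sketch: functional tangents, where zero is `id` and addition is
-- function composition. All names here are illustrative.
type FTan a = a -> a

fzero :: FTan a
fzero = id

fadd :: FTan a -> FTan a -> FTan a
fadd = (.)

-- A hypothetical reverse pass for constVec : (Int, Float) -> Vec Float.
-- It receives a functional tangent for the output, but the only
-- Vec Float it can apply it to is a zero vector -- so the problem of
-- building zeros has just moved here.
sufrevpassConstVec :: FTan [Double] -> (Int, Double) -> (Int, Double)
sufrevpassConstVec dout (n, _) =
  let zeroVec = replicate n 0.0  -- the zero we hoped to avoid creating
      dvec    = dout zeroVec
  in (0, sum dvec)               -- Int has no tangent; Float sums

main :: IO ()
main = print (sufrevpassConstVec (fadd fzero (map (+ 1.0))) (3, 7.0))
```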
There is an alternative choice of tangent space. We could come up with a sum type representing all operations that can be applied to a tangent space, for example

```haskell
data Delta a where
  Add   :: Delta a -> Delta a -> Delta a
  Zero  :: Delta a
  Leaf  :: Float -> Delta Float
  IncAt :: Int -> Delta a -> Delta (Vec a)
  DFst  :: Delta a -> Delta (a, b)
  DSnd  :: Delta b -> Delta (a, b)
```
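For illustration, a small evaluator for such a `Delta` might look like this (a hedged sketch: `applyDelta` and the `Double`/list encodings are my own, not ksc's):

```haskell
{-# LANGUAGE GADTs #-}
-- Sketch evaluator: apply an accumulated Delta to a dense value.
-- Vec is modelled as a Haskell list and Float as Double.

data Delta a where
  Add   :: Delta a -> Delta a -> Delta a
  Zero  :: Delta a
  Leaf  :: Double -> Delta Double
  IncAt :: Int -> Delta a -> Delta [a]
  DFst  :: Delta a -> Delta (a, b)
  DSnd  :: Delta b -> Delta (a, b)

-- Collapse a Delta into a dense representation by applying it to an
-- existing dense value -- which is exactly where a zero of the right
-- shape is needed again.
applyDelta :: Delta a -> a -> a
applyDelta (Add d1 d2) x  = applyDelta d2 (applyDelta d1 x)
applyDelta Zero        x  = x
applyDelta (Leaf f)    x  = x + f
applyDelta (IncAt i d) xs =
  [ if j == i then applyDelta d x else x | (j, x) <- zip [0 ..] xs ]
applyDelta (DFst d) (x, y) = (applyDelta d x, y)
applyDelta (DSnd d) (x, y) = (x, applyDelta d y)

main :: IO ()
main =
  -- Two increments at index 1 accumulate into a dense vector:
  print (applyDelta (Add (IncAt 1 (Leaf 2.0)) (IncAt 1 (Leaf 3.0)))
                    [0.0, 0.0, 0.0])
```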
But this poses more problems: once many `IncAt`s (or `Leaf`s) have accumulated, we ideally want to collapse them into a dense representation somehow.

> To expand a bit, for every type `T` we have to choose
>
> - a type `ElimBOG{T}`
> - a function `elimBog : T -> ElimBOG{T}`
> - a function `bogZero : ElimBOG{T} -> dT`
>
> such that `bogZero (elimBog t)` is the zero of the same shape as `t`.
As you've hinted above, I think what you've presented here is precisely the definition of `ShapeType{T}`, with `elimBog` being the existing primitive function `shape`. (`bogZero` is a function which doesn't currently exist in ksc, though it has an implementation in knossos.h for use in `buildFromSparse`.) Indeed, if shapes are to be useful, it seems important that `ShapeType{T}` is the most efficient representation possible, so if `ShapeType{T}` seems suboptimal, we should really fix that by changing `ShapeType{T}`, rather than by introducing a separate `ElimBOG{T}`.
Currently `ShapeType{Vec Float}` is `Vec (Tuple)`, and this may not be optimal. As we were discussing yesterday, we could optimize this in the backend by having a special representation for `Vec (Tuple)` which is as efficient as an `Integer`. (In Cgen this could be a template specialization for `tensor<N, tuple<>>`.) But since we'd have to do this optimization in every backend (and RLO would have to know about it in order to calculate the cost correctly), it might be better to change the ks definition instead.
(But I also agree that it would be a lot better to avoid creating zeros at all, if that's possible.)
It feels fine to have the shape rules include a special case for `Vec Scalar`:

```
ShapeType{Int}              = Tuple []
ShapeType{Vec Float}        = Int
ShapeType{Vec Int}          = Int
ShapeType{Tuple [T1 .. Tn]} = Tuple [ShapeType{T1} .. ShapeType{Tn}]
ShapeType{Vec T}            = Vec ShapeType{T}
```
We don't require that `T` be inferrable from `ShapeType{T}`, as we can always arrange to have `T` nearby.
If we haven't already, let's make a doc somewhere describing ShapeType{T} -- either in a Haskell Note, or perhaps Shapes.md beside Shapes.hs, or in doc/sphinx/Transforms/Shape.rst.
And in that doc, we should write the invariants we want. I believe one invariant might be

> `S{T} = ShapeType{T}` is such that from a value of `S{T}`, and `T`, a zero object can be constructed.

Or

> `mkZero@T (shape z) == z` when `z` is a zero of type `T`.
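A small Haskell model of these shape rules and the `mkZero` invariant (the `Value`/`Ty` encoding is my own for illustration, not ksc's):

```haskell
-- Hypothetical model: ksc values and the proposed ShapeType rules,
-- reified as Haskell data. Shapes are themselves values.
data Value
  = VInt Int
  | VFloat Double
  | VVecFloat [Double]   -- Vec Float, stored densely
  | VTuple [Value]
  | VVec [Value]         -- Vec T for non-scalar T
  deriving (Eq, Show)

data Ty = TFloat | TVecFloat | TTuple [Ty] | TVec Ty

-- shape follows the proposed rules, including the Vec Float special case
shape :: Value -> Value
shape (VInt _)       = VTuple []
shape (VFloat _)     = VTuple []
shape (VVecFloat xs) = VInt (length xs)  -- ShapeType{Vec Float} = Int
shape (VTuple vs)    = VTuple (map shape vs)
shape (VVec vs)      = VVec (map shape vs)

-- mkZero@T: rebuild a zero of type T from a shape value
mkZero :: Ty -> Value -> Value
mkZero TFloat      (VTuple []) = VFloat 0.0
mkZero TVecFloat   (VInt n)    = VVecFloat (replicate n 0.0)
mkZero (TTuple ts) (VTuple ss) = VTuple (zipWith mkZero ts ss)
mkZero (TVec t)    (VVec ss)   = VVec (map (mkZero t) ss)
mkZero _           _           = error "shape does not match type"

main :: IO ()
main = do
  -- Invariant: mkZero@T (shape z) == z when z is a zero of type T
  let z = VTuple [VVecFloat [0.0, 0.0], VFloat 0.0]
      t = TTuple [TVecFloat, TFloat]
  print (mkZero t (shape z) == z)
```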
I would also really like a doc for the other transformations that is not AD.pdf
On prims and structured names: I think we should be regarding `elim` as an implementation detail of AD, so that the answer to this question shouldn't actually matter outside of ksc. That is, `elim` can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with `getZero` or `[sufrevpass elim]`: these can always be rewritten in terms of literal zeros and `constVec`.)
This would be similar to how we deal with constructs like `[rev constVec]` at the moment: this can arise in unoptimized AD, but is always rewritten to a form which uses base primitives. Cgen does not know how to generate code for any derived functions whose base is a primitive. I think this is a good thing, and we should be minimizing the number of primitives that backends need to know about.
If `elim` doesn't appear in the .kso then I think the simplest answer becomes: this would mean that the .kso is unchanged from what we have at the moment, while still solving the disambiguation problem.
Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.
> Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.
Exactly -- I don't see this as a real problem, and it would maybe be better than having more special cases around prims.
Let's recall:

- Many prims (`eq`, `ts_add`, `build`, `fold`, `map`) are just function templates -- we would expect to modify KSC to emit edefs for each used instance. Those edefs would obtain structured names in the normal way (either literally emitted, or trivially assembled at kso parse time).
- That leaves a few truly special prims like `get$1` etc., and those can be special-cased.

> Alternatively if you want to write out all prims as structured names, this will need some modification to our other parsers; but they would be free to ignore the type information so the changes required would be fairly minimal.
I was thinking of things like `ts_add` when I wrote this; but I'm not so sure it makes sense for the "very special" primitives like `build` and `get$1`. So it feels like we're always going to end up with two kinds of prims:
- prims the backends understand: `get$i`, `build`, `fold`, ...
- prims internal to ksc: `elim`, ...

And the large majority of prims could reasonably belong to either category. But I'd have a preference for saying that anything that the backends need to understand should be in the first category; any primitives in the second category should be rewritten to other forms by ksc and not appear in the `.kso`. Then the backends only have one kind of prim to worry about.
> That is, elim can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with getZero or [sufrevpass elim]: these can always be rewritten in terms of literal zeros and constVec.)

Are you proposing a subdivision of prims into those that can appear in a `.kso` and those that can't? How do we make that distinction? In some sense `ts_scale` is an implementation detail of AD too. How should we determine whether it is allowed to appear in a `.kso`?
Seems like we're not the only ones having problems with `elim`!
> Seems like we're not the only ones having problems with elim!

Sorry - I just broke your link, in making a fix that it inspired - and now can't see how to link to a specific line in a PR....
And yes! (Although I do note that linear map AD needs neither `elim` nor the PyTorch hack...)
> And yes! (Although I do note that linear map AD needs neither elim nor the PyTorch hack...)

LM AD does have shaped zeros though: https://github.com/microsoft/knossos-ksc/blob/d9d10db1e2e7cbe2ff079d0f0aebf6d27f128c57/src/ksc/AD.hs#L93
Agreed, but there's no particularly special case there. And of course LM and SUF both are infinitely better than the PyTorch situation.
> there's no particularly special case there

Indeed not. The LM AD zeroes are storing much more than they need to!
Agreed, but they should just use `shape`?
And they are optimised away in any case. We don't see `lmZero` in emitted code? Of course we do see `constVec` zeros, but we haven't done all the work to rewrite them away.
> they should just use shape?

Yes probably. I think the problem is the same and the solution is the same.
> And they are optimised away in any case. We don't see lmZero in emitted code? Of course we do see constVec zeros, but we haven't done all the work to rewrite them away.
Right, `lmZero` is rewritten to `constVec`, or a tuple of zeros, or whatever. The zeros themselves are not optimised away in general.
> > That is, elim can be a primitive function in ksc -- but it will never appear in the .kso, and Cgen and other backends won't need to know how to generate code for it. (Similarly with getZero or [sufrevpass elim]: these can always be rewritten in terms of literal zeros and constVec.)
>
> Are you proposing a subdivision of prims into those that can appear in a `.kso` and those that can't? How do we make that distinction? In some sense `ts_scale` is an implementation detail of AD too. How should we determine whether it is allowed to appear in a `.kso`?
I do think we should try to minimize the number of prims that can appear in a `.kso`. And if we decide to make a distinction between prims that can appear in a `.kso` and those that can't, then that makes it easy to add new "private" prims in ksc without breaking consumers of `.kso` files.
I'd suggest that the test should be whether it is possible to rewrite the operation in terms of existing prims (or ordinary monomorphic functions) without sacrificing performance.
(Apologies for the delayed reply.)
Interesting. `elim` is a "SUF KS prim" since it can't be expressed in SUF KS in terms of another prim, but it isn't a "KS prim" since in KS it can just be dropped entirely.
We have a prim `elim :: A -> ()`. Its SUF/BOG-AD reverse pass is `[sufrevpass elim] :: ((), BOG) -> dA`, where the return value is the zero of type `dA`. We are free to choose `BOG` as long as the zero of type `dA` can be constructed from it. The easiest choice for `BOG` is the `mkTangentZero` of the argument that `elim` was supplied with.

But when `A = Vec Float` this BOG is much too big! If `a` is a `Vec Float` of size `n` then all we need to recover `mkTangentZero a` is `n`. Much better! Can we use this strategy to make all BOGs for `elim` small? Unfortunately not! If `r : Vec (Vec Float)` were ragged then the BOG would need to at least be a `Vec Integer` describing the length of each nested vector.

Still, we can take advantage of this strategy in some special cases. There is still a difficulty though. If we choose `Integer` for the BOG of `elim` on `Vec Float` then what do we choose for the BOG of `elim` on `Vec (Float, Float)` etc.? Given that `[sufrevpass elim] :: ((), BOG) -> dA` we need to be able to determine `dA` from `BOG`. Some possibilities:

1. Come up with an ad hoc encoding, e.g. the BOG of `Vec Float` could be `((), Integer)` and the BOG of `Vec (Float, Float)` could be `(((), ()), Integer)`. I don't hold out much hope of this possibility being feasible to work with in the general case.
2. Use `Dummy` for tagging values. However, `Dummy` is not erased at run time so it will incur a performance penalty (perhaps slight). (N.B. we can't erase `Dummy` at the moment because it is used to fake up sum types from tuples. If we added real sum types then we could use `Dummy` for this purpose.)
3. Introduce a new type, like `Dummy`, but that is actually erased at run time.
4. Introduce structuring to the names of prims. `elim : A -> ()` could become `[elim A] : A -> ()`, and `[sufrevpass elim] : ((), BOG) -> dA` would become `[sufrevpass [elim A]] : ((), BOG) -> dA`. The problem would then vanish. (Perhaps this is the natural solution. Solving this problem for userfuns was why it was essential to introduce structured names.)

Thoughts @awf?