microsoft / knossos-ksc

Compiler with automatic differentiation

Initial plumbing for vmap #1010

Closed awf closed 2 years ago

awf commented 2 years ago

No new functionality; this PR replaces "elementwise" in most places with a new class, VecSpec, which has derived classes VecSpec_None, VecSpec_Elementwise, and VecSpec_VMap, with meanings as follows:

Suppose a knossos function

       def f(x : Tensor) -> Tensor 

is called with x a tensor of size [PxMxN].

The vectorization specification (the VecSpec subclass) decides how f is mapped over this argument:

     None: f is compiled to take rank 3 tensors.
     Elementwise: f is compiled to take floats (rank 0), and is computed elementwise.
     VMap: f is compiled to take rank 2 tensors, and mapped over the first dimension.
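
For illustration only, the three modes can be sketched in plain Python. Nested lists stand in for tensors, and `map_none`, `map_elementwise`, `map_vmap`, `rank` and `calls` are made-up helper names, not anything in this PR or in knossos-ksc. The sketch only shows what rank of argument `f` sees in each mode, and how often it is called:

```python
# Illustrative sketch only: nested lists stand in for tensors, and every
# helper name here is hypothetical, not knossos-ksc API.

def rank(x):
    """Number of nested list levels (0 for a scalar)."""
    r = 0
    while isinstance(x, list):
        x = x[0]
        r += 1
    return r

def map_none(f, x):
    # None: f receives the whole rank-3 tensor.
    return f(x)

def map_elementwise(f, x):
    # Elementwise: f receives each scalar (rank 0) in turn.
    if isinstance(x, list):
        return [map_elementwise(f, xi) for xi in x]
    return f(x)

def map_vmap(f, x):
    # VMap: f receives each rank-2 slice along the first dimension.
    return [f(xi) for xi in x]

def calls(mapper, x):
    """Return (number of calls to f, set of argument ranks f saw)."""
    seen = []
    def probe(t):
        seen.append(rank(t))
        return t
    mapper(probe, x)
    return len(seen), set(seen)

P, M, N = 2, 3, 4
x = [[[0.0] * N for _ in range(M)] for _ in range(P)]  # shape PxMxN

print(calls(map_none, x))         # (1, {3})
print(calls(map_elementwise, x))  # (24, {0})
print(calls(map_vmap, x))         # (2, {2})
```

With P=2, M=3, N=4, the function is called once on a rank 3 input, 24 times on floats, or twice on 3x4 matrices, respectively.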

dcrc2 commented 2 years ago

Great - so vmap maps over the outermost dimension, whereas elementwise maps over all elements. Thanks to @ryotatomioka for helping me understand why vmap can't deduce the number of dimensions to map over automatically. (If you'd already compiled a ks function then this would be well-defined, but you can't compile the ks function without providing an example input, and you can't get an example input without knowing how many dimensions to map over...)

Ryota also mentioned that similar libraries allow vmap(vmap(f)) for mapping over more than one dimension. Do we want to support this? It looks like much of this PR would need to be rewritten to make this possible.

If we do have multidimensional vmap then we wouldn't really need elementwise as well. I wouldn't mind removing elementwise, even though that's the version that I originally implemented, if that would avoid having duplicated functionality.
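
To make the overlap concrete, here is a rough pure-Python sketch of a composable `vmap` next to `elementwise`. Both are hypothetical stand-ins operating on nested lists, not the knossos-ksc implementations. It shows that nesting one `vmap` per dimension reproduces the elementwise result, while fewer `vmap`s leave the inner function looking at a higher-rank argument:

```python
# Hypothetical stand-ins for illustration; not the knossos-ksc API.

def vmap(f):
    """Lift f to map over the outermost dimension of a nested list."""
    return lambda xs: [f(x) for x in xs]

def elementwise(f):
    """Lift f to map over every scalar, whatever the rank."""
    def g(x):
        return [g(xi) for xi in x] if isinstance(x, list) else f(x)
    return g

def square(v):
    return v * v

x = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # shape 2x2x2

# Three nested vmaps over a rank-3 input reach the scalars, as elementwise does.
assert vmap(vmap(vmap(square)))(x) == elementwise(square)(x)

# Two nested vmaps leave the inner function looking at rank-1 rows.
row_sums = vmap(vmap(sum))(x)
print(row_sums)  # [[3.0, 7.0], [11.0, 15.0]]
```

The difference that remains is who decides the depth: with nested `vmap` the caller states the number of mapped dimensions explicitly, whereas `elementwise` always goes down to scalars.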

awf commented 2 years ago

Good questions. Roughly this:

So while we work that out, the next thing I really want is something that generalizes elementwise by passing an example input for the innermost function; then we get the behaviour we want of mapping a function which takes an MxN input over some arbitrary number of mapping dimensions: PxQx...xMxN.

So I'm inclined to work on the latter next, then go back to composability of vmap.
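
A minimal sketch of what "generalizing elementwise with an example input" could look like, in plain Python with nested lists as tensors. `map_leading`, `shape`, `total` and `ones` are names invented for this sketch, and checking the trailing sizes against the example (rather than only the rank) is an assumption, not something decided in this thread. The number of mapping dimensions is taken as the argument's rank minus the example's rank:

```python
# Illustrative only: every name here is hypothetical, not knossos-ksc API.

def shape(x):
    """Shape of a rectangular nested list, e.g. (P, M, N)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)

def map_leading(f, example):
    """Map f over all leading dimensions beyond the example input's shape."""
    inner = shape(example)
    def g(x):
        s = shape(x)
        n_map = len(s) - len(inner)  # number of mapping dimensions
        if n_map < 0 or s[n_map:] != inner:
            raise ValueError(f"trailing dims of {s} do not match example {inner}")
        def go(t, depth):
            return f(t) if depth == 0 else [go(ti, depth - 1) for ti in t]
        return go(x, n_map)
    return g

def total(m):
    # The innermost function: takes an MxN matrix, returns a scalar.
    return sum(sum(row) for row in m)

def ones(*dims):
    """Nested list of 1.0 with the given shape."""
    return 1.0 if not dims else [ones(*dims[1:]) for _ in range(dims[0])]

M, N = 2, 3
mapped = map_leading(total, example=ones(M, N))

print(mapped(ones(M, N)))               # 6.0
print(shape(mapped(ones(4, M, N))))     # (4,)
print(shape(mapped(ones(5, 4, M, N))))  # (5, 4)
```

The same `mapped` function handles MxN, PxMxN and PxQxMxN inputs, because the example fixes the inner shape up front.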

dcrc2 commented 2 years ago

> So while we work that out, the next thing I really want is something that generalizes elementwise by passing an example input for the innermost function; then we get the behaviour we want of mapping a function which takes an MxN input over some arbitrary number of mapping dimensions: PxQx...xMxN.

Right, I see that would work (requiring an explicit compile in order to provide the example input). This PR looks like a good starting point then.