awf closed this 2 years ago
Great, so `vmap` maps over the outermost dimension, whereas `elementwise` maps over all elements. Thanks to @ryotatomioka for helping me understand why `vmap` can't deduce the number of dimensions to map over automatically. (If you had already compiled a ks function, the number would be well-defined; but you can't compile the ks function without providing an example input, and you can't get an example input without knowing how many dimensions to map over...)
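To make the distinction concrete, here is a minimal pure-Python sketch that uses nested lists as stand-in tensors. The `vmap` and `elementwise` helpers below are hypothetical illustrations, not the knossos API. `vmap(f)` calls `f` once per row of the outermost dimension, while `elementwise(f)` recurses all the way down to scalars.

```python
from typing import Any, Callable


def vmap(f: Callable[[Any], Any]) -> Callable[[list], list]:
    """Hypothetical: map f over the outermost dimension only (one call per row)."""
    def mapped(x: list) -> list:
        return [f(row) for row in x]
    return mapped


def elementwise(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Hypothetical: map f over every scalar, however deeply x is nested."""
    def mapped(x: Any) -> Any:
        if isinstance(x, list):
            return [mapped(item) for item in x]
        return f(x)
    return mapped


x = [[1.0, -2.0, 3.0],
     [-4.0, 5.0, -6.0]]  # a 2x3 "tensor"

row_sums = vmap(sum)(x)         # sum receives each length-3 row: [2.0, -5.0]
abs_vals = elementwise(abs)(x)  # abs receives each scalar
```

Note that nothing in `sum` itself says "I expect a vector"; the mapping depth is chosen by the caller. That is why `vmap` can't work it out on its own without an example input.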
Ryota also mentioned that similar libraries allow `vmap(vmap(f))` for mapping over more than one dimension. Do we want to support this? It looks like much of this PR would need to be rewritten to make it possible.
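As a sketch of how composition peels off one dimension per `vmap`, again using a hypothetical nested-list `vmap` rather than the knossos one: `norm_sq` naturally takes a single vector, so a 2x2x3 input needs two levels of mapping.

```python
from typing import Any, Callable


def vmap(f: Callable[[Any], Any]) -> Callable[[list], list]:
    """Hypothetical: map f over the outermost dimension of a nested list."""
    return lambda x: [f(row) for row in x]


def norm_sq(v: list) -> float:
    """A function whose natural input is one vector."""
    return sum(e * e for e in v)


# A 2x2x3 input: two batches of two length-3 vectors.
x = [[[1, 0, 0], [0, 2, 0]],
     [[1, 1, 1], [0, 0, 3]]]

# vmap(norm_sq)(x) would hand norm_sq a 2x3 matrix and fail;
# composing twice peels two dimensions so norm_sq sees vectors.
twice = vmap(vmap(norm_sq))
result = twice(x)  # [[1, 4], [3, 9]]
```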
If we do have multidimensional `vmap`, then we wouldn't really need `elementwise` as well. I wouldn't mind removing `elementwise`, even though that's the version I originally implemented, if that would avoid duplicated functionality.
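The redundancy can be seen in a small sketch: for a scalar function, `elementwise` is just `vmap` composed as many times as the input has dimensions. The helpers are hypothetical, and `rank` assumes a regular, non-empty nesting of lists.

```python
from typing import Any, Callable


def vmap(f: Callable[[Any], Any]) -> Callable[[list], list]:
    """Hypothetical: map f over the outermost dimension of a nested list."""
    return lambda x: [f(row) for row in x]


def rank(x: Any) -> int:
    """Number of nested list levels (0 for a scalar); assumes non-empty, regular nesting."""
    n = 0
    while isinstance(x, list):
        x = x[0]
        n += 1
    return n


def elementwise(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """elementwise(f) expressed as vmap composed rank(x) times."""
    def mapped(x: Any) -> Any:
        g = f
        for _ in range(rank(x)):
            g = vmap(g)
        return g(x)
    return mapped
```

For example, `elementwise(lambda v: v * 10)([[1, 2], [3, 4]])` builds `vmap(vmap(f))` and returns `[[10, 20], [30, 40]]`.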
Good questions. Roughly this:
So while we work that out, the next thing I really want is something that generalizes `elementwise` by passing an example input for the innermost function; then we get the behaviour we want: a function that takes an MxN input can be mapped over an arbitrary number of leading dimensions, i.e. over an input of shape PxQx...xMxN.
So I'm inclined to work on that next, then go back to the composability of `vmap`.
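One way to picture this generalization, as a sketch under assumptions: the example input fixes the rank of the innermost function's argument, and the number of mapped dimensions is the input's rank minus the example's rank. `map_with_example` and the nested-list helpers are hypothetical names, and this version checks only ranks, not the sizes M and N.

```python
from typing import Any, Callable


def vmap(f: Callable[[Any], Any]) -> Callable[[list], list]:
    """Hypothetical: map f over the outermost dimension of a nested list."""
    return lambda x: [f(row) for row in x]


def rank(x: Any) -> int:
    """Number of nested list levels; assumes non-empty, regular nesting."""
    n = 0
    while isinstance(x, list):
        x = x[0]
        n += 1
    return n


def map_with_example(f: Callable[[Any], Any], example: Any) -> Callable[[Any], Any]:
    """Map f over however many leading dimensions x has beyond example's rank."""
    inner = rank(example)

    def mapped(x: Any) -> Any:
        extra = rank(x) - inner
        if extra < 0:
            raise ValueError(f"input rank {rank(x)} is below example rank {inner}")
        g = f
        for _ in range(extra):
            g = vmap(g)
        return g(x)
    return mapped


def trace(m: list) -> float:
    """Takes a single square matrix."""
    return sum(m[i][i] for i in range(len(m)))


f = map_with_example(trace, [[0, 0], [0, 0]])  # any 2x2 example fixes the inner rank at 2
single = f([[1, 2], [3, 4]])                      # 5 (no mapped dimensions)
batch = f([[[1, 0], [0, 1]], [[2, 0], [0, 3]]])   # [2, 5] (P = 2)
```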
Right, I see that would work (it requires an explicit `compile` to provide the example input). This PR looks like a good starting point, then.
No new functionality; this PR replaces `elementwise` in most places with a new class, `VecSpec`, which has derived classes `VecSpec_None`, `VecSpec_Elementwise` and `VecSpec_VMap`, with meanings as follows. Suppose a knossos function `f` is called with `x`, a tensor of size [PxMxN].
The vectorization specification (a `VecSpec` subclass) decides how `f` is mapped over this argument as follows:
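A minimal Python sketch of how the three subclasses could behave on a [PxMxN] argument follows. It is based on the discussion above (`vmap` maps the outermost dimension, `elementwise` maps every element) plus the assumption that `VecSpec_None` means no mapping. The class names come from this PR; the `apply` method and the nested-list representation are hypothetical.

```python
from typing import Any, Callable


class VecSpec:
    """Base class: how f is mapped over a tensor argument (hypothetical API)."""
    def apply(self, f: Callable[[Any], Any], x: Any) -> Any:
        raise NotImplementedError


class VecSpec_None(VecSpec):
    """Assumed: no mapping, so f receives the whole PxMxN tensor once."""
    def apply(self, f, x):
        return f(x)


class VecSpec_Elementwise(VecSpec):
    """f receives each of the P*M*N scalars."""
    def apply(self, f, x):
        if isinstance(x, list):
            return [self.apply(f, item) for item in x]
        return f(x)


class VecSpec_VMap(VecSpec):
    """f receives each of the P slices of shape MxN."""
    def apply(self, f, x):
        return [f(slice_) for slice_ in x]


x = [[[1, 2], [3, 4]],
     [[5, 6], [7, 8]]]  # P = 2, M = 2, N = 2


def count_calls(spec: VecSpec) -> int:
    """How many times spec invokes f on x."""
    calls = []
    spec.apply(calls.append, x)
    return len(calls)
```

Under these assumptions `count_calls` gives 1 for `VecSpec_None`, 8 (P*M*N) for `VecSpec_Elementwise` and 2 (P) for `VecSpec_VMap`.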