I should add a note on Expr handling: I think we can get away with treating our shape rules as if dynamic operations always return shapes with an upper bound. We could also explicitly mark dimensions as dynamic when creating the expression.
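One way to picture "shape rules as if dynamic operations always return shapes with an upper bound" is the minimal Python sketch below. It is not Nx code: the Dynamic marker, nonzero_shape, and concatenate_shape are hypothetical names for illustration. A dynamic dimension carries only its upper bound, and every downstream rule computes on bounds while propagating the "dynamic" mark.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple, Union


@dataclass(frozen=True)
class Dynamic:
    """Hypothetical marker: size known only at runtime, never above upper_bound."""
    upper_bound: int


Dim = Union[int, Dynamic]
Shape = Tuple[Dim, ...]


def bound(dim: Dim) -> int:
    # A static dimension is its own upper bound.
    return dim.upper_bound if isinstance(dim, Dynamic) else dim


def nonzero_shape(shape: Shape) -> Shape:
    # The number of non-zero entries is data-dependent, but it can never
    # exceed the total number of elements, so that product is the bound.
    return (Dynamic(prod(bound(d) for d in shape)), len(shape))


def concatenate_shape(a: Shape, b: Shape, axis: int) -> Shape:
    # Shape rule written purely against upper bounds; an output dimension is
    # marked dynamic wherever either input dimension was dynamic.
    if len(a) != len(b):
        raise ValueError("rank mismatch")
    out = []
    for i, (da, db) in enumerate(zip(a, b)):
        dynamic = isinstance(da, Dynamic) or isinstance(db, Dynamic)
        if i == axis:
            size = bound(da) + bound(db)
        elif bound(da) == bound(db):
            size = bound(da)  # the runtime sizes still have to agree
        else:
            raise ValueError(f"bounds differ on axis {i}")
        out.append(Dynamic(size) if dynamic else size)
    return tuple(out)
```

For example, concatenate_shape(nonzero_shape((3, 4)), (5, 2), axis=0) yields (Dynamic(17), 2): the rule never needs the runtime count, only the bound of 12.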
@seanmor5 Slice is part of the example list as well. I'm working on a BERT model, and it requires dynamic Slice and Reshape operations.
Although this functionality exists in XLA, it is not used by JAX, which makes us skeptical about relying on it. This is especially true because the examples that seem to need it, such as those in AxonONNX, are not actually dynamic; that is just the way ONNX has to export them. So we don’t intend to pursue this further for now. Thanks @dmorn!
While implementing Axon ONNX, I've hit snags a few times with operators that depend on runtime values for shape-based operations. Some examples of ONNX operators that depend on runtime values where we require static values:
While my primary motivation for researching this was ONNX compatibility, I think that slowly introducing more dynamic operations into the API will give users more flexibility. The main focus will be on the XLA implementations, because those require the most thought. From what I can tell, TorchX should support all of these dynamic operations out of the box, and the binary backend should be easy to implement as well.
Other Frameworks
PyTorch is pretty flexible and supports all of these, IIRC. JAX has an (undocumented) djit, which is basically a dynamic-shape jit. You can see their implementations of various operators here:
https://github.com/google/jax/blob/68e9e1c26d5d9439d03d09a10b8f9b26e8258383/jax/experimental/djax.py
I think their implementation does not rely on the XLA transformations I'll discuss here (but I could be wrong).
Implementation
Static shape requirements are mostly imposed to enable optimizations around buffer allocation (IIUC). We can work around this requirement with dynamic dimensions. A dynamic dimension is a dimension whose size is only known at runtime but is bounded by a static upper bound. We guarantee to XLA (and other compilers) that the size of the dimension will never exceed that upper bound, but within that bound we can change it however we want. XLA has a few operators which allow us to work with dynamic dimensions:
For some of the operators above, we can pretty much get away with a dynamic reshape: reshape, squeeze, and unsqueeze. For resize and expand, I believe we can implement them in terms of SetDimensionSize and GetDimensionSize. For split, I believe it will be a DynamicReshape and a SetDimensionSize.
I have some ideas for how to implement non-zero, but I haven't gotten anything concrete yet. My idea is basically to get the non-zero values, compute the size of each axis with a sum, argsort to get the indices, and then set the dimension size to the runtime values computed from the sum. It's important to note that once we have non-zero, we get boolean indexing for free (simply by passing the indices to any one of our existing indexing operations).
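To make the dynamic-dimension idea and the non-zero plan concrete, here is a minimal Python sketch, not Nx or XLA code. It assumes a hypothetical BoundedVector type: a buffer allocated once at a static upper bound plus a logical size that may change at runtime. Its get_dimension_size/set_dimension_size methods only stand in for the role XLA's GetDimensionSize/SetDimensionSize play; they do not reproduce their API or exact semantics. nonzero follows the steps described above (mask, sum for the runtime size, stable argsort, set the dimension size), and the hypothetical boolean_index helper shows boolean indexing falling out of a plain gather.

```python
from typing import List, Sequence, Tuple


class BoundedVector:
    """Hypothetical 1-D tensor whose length varies at runtime up to a static bound."""

    def __init__(self, data: Sequence, upper_bound: int, pad=0):
        if len(data) > upper_bound:
            raise ValueError("data exceeds the upper bound")
        self.upper_bound = upper_bound
        # Storage is sized by the bound (what a compiler can plan statically);
        # entries past the logical size stand in for unspecified padding.
        self._buffer = list(data) + [pad] * (upper_bound - len(data))
        self._size = len(data)

    def get_dimension_size(self) -> int:
        return self._size

    def set_dimension_size(self, size: int) -> None:
        # The only guarantee we give: the size never exceeds the static bound.
        if not 0 <= size <= self.upper_bound:
            raise ValueError(f"size {size} outside [0, {self.upper_bound}]")
        self._size = size

    def values(self) -> list:
        return self._buffer[: self._size]

    def dynamic_reshape(self, rows: int, cols: int) -> List[list]:
        # A runtime target shape is fine as long as it covers the valid prefix.
        if rows * cols != self._size:
            raise ValueError("target shape does not match the runtime size")
        vals = self.values()
        return [vals[r * cols : (r + 1) * cols] for r in range(rows)]


def unravel(flat: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major flat index -> coordinates."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat % dim)
        flat //= dim
    return tuple(reversed(coords))


def ravel(coords: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
    """Coordinates -> row-major flat index."""
    flat = 0
    for c, dim in zip(coords, shape):
        flat = flat * dim + c
    return flat


def nonzero(flat_values: Sequence, shape: Tuple[int, ...]) -> BoundedVector:
    mask = [1 if v != 0 else 0 for v in flat_values]
    count = sum(mask)  # runtime size of the output's dynamic axis
    # sorted() is stable, so sorting on the negated mask moves the non-zero
    # positions to the front while keeping their original order.
    order = sorted(range(len(mask)), key=lambda i: -mask[i])
    coords = BoundedVector([unravel(i, shape) for i in order], len(mask))
    coords.set_dimension_size(count)
    return coords


def boolean_index(data: Sequence, mask: Sequence, shape: Tuple[int, ...]) -> list:
    # Boolean indexing is just a gather with the indices produced by nonzero.
    return [data[ravel(c, shape)] for c in nonzero(mask, shape).values()]
```

For a 2x3 mask [0, 1, 0, 1, 1, 0], nonzero keeps a buffer of six coordinates but a runtime size of 3, and boolean_index gathers only the three selected elements.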
ReduceSum, I think, is going to be more challenging, and might not even be worth pursuing.
AutoGrad
I know JAX's djaxprs implement some AD rules, but I haven't dug too far into them. I think for an initial implementation it would be best to just raise an unimplemented error.
Compatibility with other backends
It's possible that other compilers and backends will not support dynamic shapes. I think, though, that our public-facing API should not be more restrictive than the backends that implement it. We should take the stance of allowing whatever is reasonably possible and deferring to the backends to be more restrictive where necessary.
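The stance above (a permissive front end, with restrictions living in the backends) can be sketched in a few lines of Python. Everything here is hypothetical and only illustrates the dispatch shape: the public nonzero function performs no static-shape check of its own, and a backend that cannot produce data-dependent shapes refuses the call itself.

```python
class UnsupportedOperation(Exception):
    """Hypothetical error a backend raises for operations it cannot run."""


class StaticOnlyBackend:
    # A backend that requires every shape to be known at compile time.
    def nonzero(self, values):
        raise UnsupportedOperation("nonzero needs dynamic shapes; not supported here")


class DynamicBackend:
    # A backend that can produce outputs whose size depends on the data.
    def nonzero(self, values):
        return [i for i, v in enumerate(values) if v != 0]


def nonzero(values, backend):
    # The public API imposes no static-shape restriction; it dispatches,
    # and each backend decides whether it can honor the call.
    return backend.nonzero(values)
```

With this split, adding a dynamic operation to the public API never blocks backends that can support it, while static-only backends fail with a clear, backend-specific error.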