elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Runtime dynamic shapes for more flexible Nx implementations #610

Closed: seanmor5 closed this issue 2 years ago

seanmor5 commented 2 years ago

While implementing Axon ONNX I've hit snags a few times when implementing operators that depend on runtime values for shape-based operations. Some examples of ONNX operators that depend on runtime values where we require static values:

- Reshape
- Squeeze / Unsqueeze
- Resize
- Expand
- Split
- NonZero
- ReduceSum

While my primary motivation for researching this was ONNX compatibility, I think that if we slowly introduce more dynamic operations into the API we will give users more flexibility. The main focus here will be on the XLA implementations, because those require the most thought. From what I can tell, TorchX should support all of these dynamic operations out of the box, and the binary backend should be easy to implement as well.

Other Frameworks

PyTorch is pretty flexible and supports all of these IIRC. Jax has an (undocumented) djit which is basically a dynamic-shape jit. You can see their implementations of various operators here: https://github.com/google/jax/blob/68e9e1c26d5d9439d03d09a10b8f9b26e8258383/jax/experimental/djax.py

I think their implementation does not rely on the XLA transformations I'll discuss here (but I could be wrong).

Implementation

Static shape requirements are mostly imposed to enable optimizations around buffer allocation (IIUC). We can work around this requirement with dynamic dimensions. A dynamic dimension is a dimension whose size is only known at runtime but has a compile-time upper bound: we guarantee to XLA (and other compilers) that the size will never exceed that bound, and within it we can change the size however we want. XLA has a few operators that allow us to work with dynamic dimensions:

- GetDimensionSize
- SetDimensionSize
- DynamicReshape
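
As a rough sketch of the "upper bound + runtime size" idea, here is how the semantics can be emulated today with only the static Nx API, using masking as a stand-in for what SetDimensionSize would do. The `set_dimension_size/3` call mentioned in the comments is hypothetical and does not exist in Nx:

```elixir
upper_bound = 8

# Buffer allocated for the upper bound; only the first `size` rows are logically valid.
data = Nx.iota({upper_bound, 2})
size = Nx.tensor(5)

# Today we can only emulate the runtime-size semantics by masking invalid rows:
row_ids = Nx.iota({upper_bound, 1})
mask = Nx.less(row_ids, size)
masked = Nx.select(Nx.broadcast(mask, {upper_bound, 2}), data, 0)

# With dynamic dimensions, XLA would instead let us declare something like
#   set_dimension_size(data, size, axis: 0)   # hypothetical, shape {<=8, 2}
# while still planning buffers for the static {8, 2} upper bound.
```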

For some of the operators above, we can pretty much get away with a dynamic reshape: reshape, squeeze, unsqueeze.
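
For reference, with today's static API these are already just reshapes with statically known output shapes; the dynamic variant mentioned below is hypothetical:

```elixir
t = Nx.iota({2, 1, 3})

# squeeze/unsqueeze are reshapes whose output shapes are statically known:
squeezed = Nx.squeeze(t, axes: [1])       # {2, 3}
unsqueezed = Nx.new_axis(squeezed, 0)     # {1, 2, 3}
same = Nx.reshape(t, {2, 3})              # same result as the squeeze above

# A DynamicReshape-backed variant would accept runtime sizes instead, e.g.
# a scalar tensor `n` bounded above by 2 in place of the literal 2.
```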

For resize and expand, I believe we can implement them in terms of SetDimensionSize and GetDimensionSize.
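
For comparison, static expand is already covered by `Nx.broadcast/2`; a hedged sketch of how the dynamic case could build on it follows, where `set_dimension_size` is again hypothetical:

```elixir
t = Nx.tensor([[1], [2], [3]])            # {3, 1}

# Static expand is a broadcast:
expanded = Nx.broadcast(t, {3, 4})        # {3, 4}

# With a runtime target size n <= 4, the idea would be to broadcast to the
# upper bound and then mark the axis as dynamic:
#   expanded |> set_dimension_size(n, axis: 1)   # hypothetical, shape {3, <=4}
```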

For split, I believe it will be a DynamicReshape and a SetDimensionSize.
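
For contrast, a static split is just a pair of slices today; under this proposal the lengths would come from runtime values instead (the dynamic part is only described in comments):

```elixir
t = Nx.iota({6, 2})

# Static split into sizes [2, 4] along axis 0:
a = Nx.slice_along_axis(t, 0, 2, axis: 0)   # {2, 2}
b = Nx.slice_along_axis(t, 2, 4, axis: 0)   # {4, 2}

# With runtime split sizes, each piece would be sliced to the upper bound (6)
# and then trimmed to its runtime length with SetDimensionSize.
```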

I have some ideas for how to implement non-zero, but I haven't gotten anything concrete yet. My idea is basically: compute a non-zero mask, compute the number of non-zero entries along each axis with a sum, argsort to get the indices, and then set the dimension size to the runtime value computed from the sum (see the sketch below). It's important to note that once we have non-zero, we get boolean indexing for free (simply by passing the indices to any one of our existing indexing operations).
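
A minimal sketch of that idea for the 1-D case, using only existing Nx functions; the final trimming step (`set_dimension_size`) is hypothetical and shown only as a comment:

```elixir
t = Nx.tensor([0, 3, 0, 7, 5])

mask = Nx.not_equal(t, 0)                    # [0, 1, 0, 1, 1]
count = Nx.sum(mask)                         # 3, only known at runtime

# Sorting the mask in descending order pushes the indices of the non-zero
# entries to the front (assuming a stable sort preserves their relative order).
order = Nx.argsort(mask, direction: :desc)   # e.g. [1, 3, 4, 0, 2]

# The dynamic-shape step would then be:
#   nonzero = set_dimension_size(order, count, axis: 0)   # hypothetical
# Boolean indexing falls out of the existing indexing operations:
selected = Nx.take(t, order)                 # non-zero values come first
```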

ReduceSum I think is going to be more challenging - and might not even be worth pursuing.

AutoGrad

I know Jax's djaxprs implement some AD rules, but I haven't gotten too far into them. I think for an initial implementation it would be best to just raise unimplemented.

Compatibility with other backends

It's possible that other compilers and backends will not support dynamic shapes. I think, though, that our public-facing API should not be more restrictive than the backends that implement it. We should take the stance of allowing whatever is reasonably possible and deferring to the backends to be more restrictive where necessary.

seanmor5 commented 2 years ago

I should add a note on Expr handling - I think we can get away with treating our shape rules as if dynamic operations always return shapes at their upper bound. We could also mark dimensions as dynamic explicitly when creating the expression.
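
One possible (purely hypothetical) way to mark a dimension as dynamic in the expression would be to tag it with its upper bound, so the shape rules can keep operating on the bound:

```elixir
# Hypothetical representation, not how Nx.Defn.Expr stores shapes today.
static_shape  = {32, 128}
dynamic_shape = {32, {:dynamic, 128}}        # second axis is <= 128 at runtime

# Shape rules would resolve a dimension to its upper bound:
upper_bound = fn
  {:dynamic, bound} -> bound
  size when is_integer(size) -> size
end

dynamic_shape |> Tuple.to_list() |> Enum.map(upper_bound)   # [32, 128]
```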

dmorn commented 2 years ago

@seanmor5 Slice belongs on the example list as well. I'm working on a BERT model and it requires dynamic Slice and Reshape operations.
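
For what it's worth, IIRC `Nx.slice` already accepts runtime (scalar tensor) start indices; it's the lengths that must be static, which is exactly the part that would need dynamic dimensions. A small sketch under that assumption:

```elixir
t = Nx.iota({1, 10, 8})

# Runtime start offsets already work (start indices may be scalar tensors):
start = Nx.tensor(2)
window = Nx.slice(t, [0, start, 0], [1, 4, 8])   # {1, 4, 8}

# The BERT case also needs the *length* (the sequence dimension) to be a
# runtime value, which is what dynamic Slice/Reshape would unlock.
```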

josevalim commented 2 years ago

Although this functionality exists in XLA, it is not used by JAX, and that makes us skeptical about relying on it, especially because the examples that seem to require it in AxonONNX are not actually dynamic; that is just the way ONNX exports them. So we don't intend to pursue this further for now. Thanks @dmorn!