jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support for ragged arrays, like torch.nested #17863

Open beneisner opened 1 year ago

beneisner commented 1 year ago

One of the more compelling recent additions to the torch ecosystem is NestedTensor; see https://pytorch.org/docs/stable/nested.html.

Basically, this is a new primitive in the Torch ecosystem that allows tensors which are "ragged" along one dimension (i.e., a tensor of shape B x ? x 3), along with a set of ops that have been optimized to support this style of batching.

Ragged batches are very common in a number of domains: variable-length sequences in NLP tasks, 3D vision tasks such as point cloud and mesh analysis (which often have variable numbers of nodes/faces), graph processing, etc. When batching for parallel/vectorized computation, most researchers tend to use padding or clever (manual) operations on tensors of shape (sum_i N_i) x 3, where N_i is the number of elements in the i-th element of the batch.

Examples of such clever vectorization:

There are a few well-known tricks for doing these kinds of operations (using existing operations in tensor libraries, as well as some custom accelerator code for specific ops). Aggregating them as primitives directly inside a tensor library, and attempting to cover as many commonly used ops as possible, would be extremely useful for the community.
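For illustration, here is a minimal JAX sketch of the "flattened (sum N_i) x 3 plus segment ids" trick mentioned above. The array sizes and segment layout are made up, and this is just the standard segment-op workaround, not a proposed API.

# Minimal sketch of the common "flattened + segment ids" workaround in JAX.
# Sizes (10 and 20 points) are arbitrary, chosen for illustration only.
import jax
import jax.numpy as jnp

# Two "batch elements" with 10 and 20 points each, stored as one
# (sum N_i) x 3 array plus an index saying which element each row belongs to.
points = jnp.concatenate([jnp.ones((10, 3)), 2.0 * jnp.ones((20, 3))])
segment_ids = jnp.concatenate([jnp.zeros(10, dtype=jnp.int32),
                               jnp.ones(20, dtype=jnp.int32)])

# Per-element reductions can be written with segment ops...
per_element_sum = jax.ops.segment_sum(points, segment_ids, num_segments=2)

# ...and per-point ops (e.g. applying a rotation) just broadcast over the flat axis.
R = jnp.eye(3)
rotated = points @ R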

For instance, in torch.nested, you can do the following:

import torch
import torch.nested

# Two "rows" with different leading dimensions (10 and 20) but the same trailing shape.
a = torch.rand((10, 3))
b = torch.rand((20, 3))
R = torch.rand((3, 3))

# Pack the ragged rows into a single nested tensor, and replicate R to match.
nt = torch.nested.nested_tensor([a, b])
R_nt = torch.nested.nested_tensor([R, R])

# Batched matmul over the ragged batch: the result's rows have shapes (10, 3) and (20, 3).
res = nt @ R_nt

The API is currently a bit rough around the edges, but the eventual goal is for most/all ops in torch to transparently support nested tensors, so that downstream libraries can focus on implementing algorithms instead of handling memory layout and batching by hand (which is often done inefficiently, incorrectly, or behind many layers of abstraction).

What would it take for JAX to support something like this as a core primitive? I'm not very familiar with how vectorization/parallelization works under the hood (in any accelerator library), but my expectation is that there are probably many challenges in every part of the core (jit, vmap, pmap). One (potential) upshot of having some sort of support for ragged tensors would be the removal of this rough edge (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#dynamic-shapes), if enough properties can be inferred (like the tensor varying in shape only along a single dimension).
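To make that rough edge concrete, here is a small sketch (not from the issue) of the situation today: under jit, each distinct input shape triggers a fresh trace/compile, so ragged batches are typically padded to a common length and masked. N_max = 20 and the batch contents below are arbitrary choices for illustration.

# Each new shape causes a recompile under jit.
import jax
import jax.numpy as jnp

@jax.jit
def mean_point(x):
    return x.mean(axis=0)

mean_point(jnp.ones((10, 3)))   # compiles for shape (10, 3)
mean_point(jnp.ones((20, 3)))   # recompiles for shape (20, 3)

# Padded/masked alternative: one static shape (B, N_max, 3) plus a mask,
# so vmap and jit always see fixed shapes.
padded = jnp.zeros((2, 20, 3)).at[0, :10].set(1.0).at[1, :20].set(2.0)
mask = jnp.arange(20) < jnp.array([10, 20])[:, None]   # (2, 20) validity mask

def masked_mean(x, m):
    # Mean over only the valid rows of one padded element.
    return (x * m[:, None]).sum(0) / m.sum()

means = jax.vmap(masked_mean)(padded, mask)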

MoFHeka commented 4 months ago

Any progress?

vyeevani commented 4 months ago

I didn't really test this, so there are probably lurking bugs, but it's a high-level sketch of a way to do something slightly more general than ragged arrays: https://gist.github.com/vyeevani/e10c4a92bb74edf51b03d8a05e652049. I don't necessarily think this is something that would have to be added directly into JAX; it feels like it can be handled by libraries.
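For concreteness, here is a hedged, minimal sketch of what a library-level approach could look like: a small pytree container pairing padded storage with per-row lengths. This is illustrative only and is not the contents of the linked gist.

# Hypothetical "Ragged" container; all names and the padded representation are assumptions.
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Ragged:
    def __init__(self, data, lengths):
        self.data = data        # (B, N_max, ...) padded storage
        self.lengths = lengths  # (B,) true row lengths

    def tree_flatten(self):
        return (self.data, self.lengths), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def map(self, fn):
        # Apply fn to each padded row; padding is preserved as long as fn
        # acts pointwise along the ragged axis.
        return Ragged(jax.vmap(fn)(self.data), self.lengths)

rows = Ragged(jnp.zeros((2, 20, 3)), jnp.array([10, 20]))
rotated = rows.map(lambda x: x @ jnp.eye(3))  # shapes stay static, so jit/vmap work as usual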