jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Named tensors #5048

Open juliuskunze opened 3 years ago

juliuskunze commented 3 years ago

PyTorch has experimental support for named tensors, achieving some compelling design goals while keeping existing code compatible. For example, binary-op broadcasting is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX/... semantics, but it checks that the names of aligned dimensions match.
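
For illustration, a minimal sketch of that name-checking behavior using PyTorch's experimental named tensor API (the axis names below are arbitrary, and the exact error message is not shown):

import torch

x = torch.randn(3, 4, names=('batch', 'feature'))
y = torch.randn(3, 4, names=('batch', 'feature'))

z = x + y        # broadcasting is still positional, but aligned names must match
print(z.names)   # ('batch', 'feature')

w = torch.randn(3, 4, names=('time', 'feature'))
# x + w would raise an error, since the names of the first axes differ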

It would be great to have named tensors that work both in op-by-op and under function transformations in JAX.

@shoyer In https://github.com/google/jax/issues/1565 you mentioned that this could be done by wrapping JAX based on https://github.com/google/jax/pull/611. According to my current understanding, this means:

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) https://github.com/google/jax/pull/611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in accordance with https://github.com/google/jax/issues/1565. (:

Jeevesh8 commented 3 years ago

Have you started working on this, @JuliusKunze?

apaszke commented 3 years ago

We are actually working on something that will pretty much realize the plan that @JuliusKunze has outlined here, with some additional benefits too (e.g. making it very easy to shard those programs with named axes over multiple accelerators).

juliuskunze commented 3 years ago

@Jeevesh8 No, and now I won't anymore. (: @apaszke That's great to hear! Will this go into the JAX repo?

degregat commented 3 years ago

Do I assume correctly that this evolved into named axes, or is there another module I did not find?

froystig commented 3 years ago

That's correct.

juliuskunze commented 3 years ago

@apaszke @froystig That looks awesome! Rad choice to ignore the order of named axes and broadcast by name instead! That's semantically cleaner and probably more future-proof than I expected. (: What I thought would make this impractical is that it's hard to optimize misaligned axes for dot products and similar ops, where implicit transposes are needed on device. I guess the performance hit isn't so bad, or axis-order optimization could/should be automated in the future anyway? Curious about your thoughts on this.

+1 for allowing arrays and operations with named axes outside of xmap, i.e. making named-axis arrays first-class in JAX, as suggested above.
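
For reference, a minimal sketch of the named-axis style being discussed, using the experimental xmap API as documented around the time of this thread (the module path jax.experimental.maps and the axis name 'batch' are assumptions, not taken from this issue):

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def normalize(x):
    # Inside xmap, 'batch' is a named axis: the reduction refers to it by name,
    # not by position, so no manual transposes are needed.
    return x / jax.lax.psum(x, 'batch')

f = xmap(normalize, in_axes=['batch', ...], out_axes=['batch', ...])
print(f(jnp.arange(4.0)))  # each element divided by the sum over the 'batch' axis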

Bit0r commented 1 year ago

A more powerful implementation would be to use first-class dimensions; torchdim, for example, uses objects as dimension "variables".
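
For comparison, a rough sketch of the torchdim style (the functorch.dim import path and the order() call are best-effort assumptions from the torchdim documentation, not verified here):

import torch
from functorch.dim import dims

batch, feature = dims(2)   # dimension objects act as first-class "variables"
x = torch.randn(8, 3)
w = torch.randn(3)

xb = x[batch, feature]     # bind positional dims to dimension objects
wb = w[feature]            # w shares the 'feature' dimension object
y = xb * wb                # broadcasting matches dims by identity, not by position
print(y.order(batch, feature).shape)  # back to a positional tensor: torch.Size([8, 3])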

Bit0r commented 1 year ago

@apaszke Perhaps this could go further and become independent of axis position entirely? By using the named tensor feature, operations that do not depend on axis position become possible.

  1. For example, a loop could directly specify which axis to iterate over, and the framework would automatically move that axis to the front. The whole operation would be transparent, with no extra code from the user.
  2. For example, we could index a tensor directly by axis name and position along that axis, as in the sketch below.

# Hypothetical API sketch: tensor.named_shape == {'batch': 32, 'time': 100, 'hidden': 200}

# Select the slice at time=0 and hidden=0, and set it to 1000, broadcasting over 'batch'.
tensor[{'time': 0, 'hidden': 0}] = 1000

# Iterate over the 'time' axis; JAX would automatically permute dimensions,
# e.g. (batch, time, hidden) -> (time, batch, hidden), so each slice t has
# t.named_shape == {'batch': 32, 'hidden': 200}.
for t in tensor['time']:
    ...