juliuskunze opened this issue 3 years ago
Have you started working on this @JuliusKunze ?
We are actually working on something that will pretty much realize the plan that @JuliusKunze has outlined here, with some additional benefits too (e.g. making it very easy to shard those programs with named axes over multiple accelerators).
@Jeevesh8 No, and now I won't anymore. (: @apaszke That's great to hear! Will this go into the JAX repo?
Do I assume correctly that this evolved into named axes, or is there another module I did not find?
That's correct.
@apaszke @froystig That looks awesome! Rad choice to ignore the order of named axes and broadcast by name! That's semantically cleaner and probably more future-proof than I expected. (: The thing I thought would make this impractical is that it's hard to optimize misaligned axes for dot products and similar ops, where implicit transposes are needed on device. I guess the performance hit is not so bad, or axis-order optimization could/should be automated in the future anyway? Curious about your thoughts on this.
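To make the concern concrete, here is a small positional sketch (plain NumPy, with hypothetical name tuples carried alongside the arrays): contracting over a *named* axis first has to locate that axis by name, and the contraction itself then moves it into place, which is exactly the implicit transpose mentioned above.

```python
import numpy as np

# Hypothetical named arrays, represented here as (array, names) pairs.
x = np.ones((32, 100, 200))   # names: ('batch', 'time', 'hidden')
w = np.ones((50, 200))        # names: ('out', 'hidden')

# A name-based contraction over 'hidden' must first locate the axis by name...
x_names, w_names = ('batch', 'time', 'hidden'), ('out', 'hidden')
ax_x, ax_w = x_names.index('hidden'), w_names.index('hidden')

# ...and then contract; np.tensordot moves the contracted axes into position,
# which is the implicit transpose the comment worries about.
y = np.tensordot(x, w, axes=(ax_x, ax_w))  # result names: ('batch', 'time', 'out')
print(y.shape)  # (32, 100, 50)
```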
+1 for allowing arrays and operations with named axes outside of xmap, i.e., making named-axis arrays first-class in JAX as suggested above.
A more powerful implementation is to use first-class dimensions; torchdim uses objects as dimension "variables".
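To illustrate the idea of first-class dimensions, here is a minimal, self-contained sketch in plain Python/NumPy (the real torchdim API differs; the `Dim`, `bind`, and `reduce_over` names below are made up for illustration): dimensions are objects, and ops refer to them by identity rather than by position.

```python
import numpy as np

class Dim:
    """A first-class dimension 'variable', loosely in the spirit of torchdim.
    Illustrative sketch only; the actual torchdim API is different."""
    def __init__(self, name, size=None):
        self.name, self.size = name, size
    def __repr__(self):
        return f"Dim({self.name}, {self.size})"

def bind(array, *dims):
    """Associate dimension objects with an array's axes."""
    assert array.ndim == len(dims)
    for d, s in zip(dims, array.shape):
        d.size = s
    return array, dims

def reduce_over(bound, dim, op=np.sum):
    """Reduce over a dimension identified by object identity, not position."""
    array, dims = bound
    axis = dims.index(dim)
    return op(array, axis=axis), tuple(d for d in dims if d is not dim)

batch, time = Dim('batch'), Dim('time')
x, xdims = bind(np.ones((4, 7)), batch, time)
summed, remaining = reduce_over((x, xdims), time)
print(summed.shape, remaining)  # (4,) (Dim(batch, 4),)
```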
@apaszke Perhaps it could be made fully independent of axis position? With a named-tensor feature, operations could be expressed without depending on axis position at all:
```python
# tensor.named_shape == {'batch': 32, 'time': 100, 'hidden': 200}
tensor[{'time': 0, 'hidden': 0}] = 1000  # select the slice at time=0, hidden=0 and set it to 1000, with broadcasting
for t in tensor['time']:
    # JAX automatically permutes dimensions for the iteration: (batch, time, hidden) -> (time, batch, hidden)
    # t.named_shape == {'batch': 32, 'hidden': 200}
    ...
```
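The semantics sketched above can be prototyped today with a thin wrapper over NumPy. The `NamedArray` class and `iter_over` method below are hypothetical names invented for this sketch, not an existing JAX API:

```python
import numpy as np

class NamedArray:
    """Minimal prototype of the proposed interface (hypothetical, NumPy-backed)."""
    def __init__(self, data, names):
        self.data = np.asarray(data)
        self.names = tuple(names)

    @property
    def named_shape(self):
        return dict(zip(self.names, self.data.shape))

    def _key(self, sel):
        # Translate a {name: index} dict into a positional index tuple.
        return tuple(sel.get(n, slice(None)) for n in self.names)

    def __setitem__(self, sel, value):
        self.data[self._key(sel)] = value  # broadcasts like NumPy

    def iter_over(self, name):
        # Iterate over one named axis, regardless of its position.
        axis = self.names.index(name)
        rest = self.names[:axis] + self.names[axis + 1:]
        for slice_ in np.moveaxis(self.data, axis, 0):
            yield NamedArray(slice_, rest)

t = NamedArray(np.zeros((32, 100, 200)), ('batch', 'time', 'hidden'))
t[{'time': 0, 'hidden': 0}] = 1000
for s in t.iter_over('time'):
    assert s.named_shape == {'batch': 32, 'hidden': 200}
```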
PyTorch has experimental support for named tensors achieving some compelling design goals while keeping existing code compatible. For example, binop broadcasting is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX/... semantics, but checks that aligned dimension names match.
It would be great to have named tensors that work both in op-by-op and under function transformations in JAX.
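The rule described above, positional broadcasting plus a name-consistency check, can be sketched in a few lines of plain Python (`unify_names` is a hypothetical helper written for illustration; it loosely mirrors PyTorch's name-inference rule, with unnamed axes as wildcards):

```python
from itertools import zip_longest

def unify_names(a_names, b_names):
    """Align names from the right, as NumPy broadcasting aligns shapes;
    aligned names must match, and None acts as a wildcard."""
    out = []
    for an, bn in zip_longest(reversed(a_names), reversed(b_names), fillvalue=None):
        if an is not None and bn is not None and an != bn:
            raise ValueError(f"names {an!r} and {bn!r} do not match")
        out.append(an if an is not None else bn)
    return tuple(reversed(out))

print(unify_names(('batch', 'time'), ('time',)))  # ('batch', 'time')
# unify_names(('batch', 'time'), ('hidden',)) would raise ValueError
```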
@shoyer In https://github.com/google/jax/issues/1565 you mentioned that this could be done by wrapping JAX based on https://github.com/google/jax/pull/611. According to my current understanding, this means:
- An `eval_names` transform.
- A `NamedDeviceArray` subtype of `DeviceArray` that adds a `names` property.
- Ops that work on `NamedDeviceArray`s. For that, wrap `jax.numpy`, wrapping each op with the `named` transform, and implement `NamedDeviceArray` using https://github.com/google/jax/pull/611 (+1 for merging). Alternatively, one could rewrite `jax.numpy` using `numpy_dispatch.get_array_module` from https://github.com/google/jax/pull/4076 (appears cumbersome).
- `jit`ted functions propagate names when applied to `NamedDeviceArray`s.

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) https://github.com/google/jax/pull/611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in accordance with https://github.com/google/jax/issues/1565. (:
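As a rough illustration of the first step of that plan, a name-propagating op wrapper could look like the following. This is a toy NumPy-backed sketch, not the mechanism from https://github.com/google/jax/pull/611; the `named` decorator and its propagation rule are invented for illustration, and a real implementation would propagate names per-primitive inside the tracer:

```python
import numpy as np

def named(op):
    """Wrap a positional op so it consumes and produces (array, names) pairs.
    Toy propagation rule: elementwise ops keep the names of the highest-rank
    operand. Real rules would be defined per primitive."""
    def wrapper(*args):
        arrays = [a for a, _ in args]
        names = [n for _, n in args]
        out_names = max(names, key=len)
        return op(*arrays), out_names
    return wrapper

add = named(np.add)
x = (np.ones((32, 100)), ('batch', 'time'))
y = (np.ones((100,)), ('time',))
out, out_names = add(x, y)
print(out.shape, out_names)  # (32, 100) ('batch', 'time')
```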