Closed · RadostW closed this issue 3 years ago
Hi - the issue here is that the call signature of hstack accepts a single argument, which is a tuple of arrays. A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.
I'm not sure what we could do to "fix" this – maybe we could raise an error when a single array is passed to hstack, to prevent this sort of silent conversion into a Python tuple of arrays, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make the computational cost implicit in the function's signature more apparent.
What do you think?
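As an illustration of what that implicit conversion means (a sketch of the observable behavior, not the internal mechanism): passing a single array to hstack behaves as if the array had first been split into its leading-axis slices.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(12).reshape(3, 2, 2)

# Passing the array directly is treated like passing the tuple of its
# leading-axis slices, i.e. three separate (2, 2) arrays.
direct = jnp.hstack(x)
explicit = jnp.hstack(tuple(x))

np.testing.assert_array_equal(direct, explicit)
```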
Also, I don't think it's generally true that hstack of a single array can be expressed in terms of a reshape. Here's a counter-example:
>>> import jax.numpy as jnp
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jnp.hstack(x)
DeviceArray([[ 0,  1,  4,  5,  8,  9],
             [ 2,  3,  6,  7, 10, 11]], dtype=int32)
I'm not sure there's any alternative here other than to split the array into three pieces and pass them to lax.concatenate, which is what hstack currently does.
Oh, I understand better why this happened. Perhaps we can improve just the case where a single jnp array is passed as the argument.
I'm pretty sure that in all cases jnp.hstack can be expressed with jax.lax.reshape (note: not jnp.reshape) thanks to its cool feature of an optional dimensions argument.
In the case of your example it would be:
>>> import jax.numpy as jnp
>>> import jax
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jax.lax.reshape(x, (2, 6), dimensions=(1, 0, 2)) - jnp.hstack(x)
DeviceArray([[0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]], dtype=int32)
Nice, I didn't know about that! If it's not too complicated, I think special-casing the single-array case in terms of these kinds of reshapes would be worthwhile - maybe it could all be handled at the level of jnp.concatenate.
Is there a way to see which operations in a program might be undergoing such a forced conversion from XLA to numpy and back to XLA? I am currently suffering a significant slowdown in a program I have written, and I am wondering if it is due to similar issues going on in the background.
The main source of this kind of thing is calling __iter__ on a DeviceArray.
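A small example of what that iteration does (illustrative variable names, not from the thread): each step of __iter__ materializes a separate array.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(3, 2)

# list() calls __iter__, splitting x into three separate (2,) arrays;
# each element is a new array object rather than a view into x.
rows = list(x)

assert len(rows) == 3
assert all(r.shape == (2,) for r in rows)
```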
How can I find out if this is the case?
Alternatively, I think it might make sense to just raise an error when a single array is passed into hstack/vstack/stack/concatenate. We would force the user to write something more explicit like jnp.vstack(list(array)), which hopefully has more obvious performance implications.
That’s two votes for erroring on implicit tuple conversion - maybe that’s a cleaner route
Oh, I see you did suggest that already 👍
Erroring on implicit tuple conversion :+1:
As for hstack and vstack, I'd still prefer a special case with reshape, especially since the workaround is not immediately obvious.
I think I have working code for this (up to edge cases such as passing a scalar, etc.):
import jax.numpy as jnp
import jax
import numpy as np

def hstack_alternative(array):
    shp = array.shape
    return jax.lax.reshape(
        array,
        (shp[1], shp[0] * shp[2]) + shp[3:],
        dimensions=((1, 0) + tuple(range(2, len(shp)))),
    )

def vstack_alternative(array):
    shp = array.shape
    return jax.lax.reshape(array, (shp[0] * shp[1],) + shp[2:])

x = jnp.arange(12).reshape(3, 2, 2)
y = jnp.arange(2 * 3 * 5).reshape(2, 3, 5)
z = jnp.arange(2 * 3 * 5 * 7).reshape(2, 3, 5, 7)

def test_hstack_alternative():
    assert np.allclose(jnp.hstack(x), hstack_alternative(x))
    assert np.allclose(jnp.hstack(y), hstack_alternative(y))
    assert np.allclose(jnp.hstack(z), hstack_alternative(z))

def test_vstack_alternative():
    assert np.allclose(jnp.vstack(x), vstack_alternative(x))
    assert np.allclose(jnp.vstack(y), vstack_alternative(y))
    assert np.allclose(jnp.vstack(z), vstack_alternative(z))
That's very cool – it would be worth adding this code-path to vstack and hstack in my opinion.
Are you interested in putting together a PR?
I was playing with this a bit - here's the implementation of jnp.concatenate for array inputs in terms of lax.reshape:
import jax.numpy as jnp
from jax import lax
from jax._src.util import canonicalize_axis

def _concatenate(x, axis=0):
    assert isinstance(x, jnp.ndarray)
    if x.ndim == 0:
        raise ValueError("Need at least one array to concatenate.")
    if x.ndim == 1:
        raise ValueError("Zero-dimensional arrays cannot be concatenated.")
    axis = canonicalize_axis(axis, x.ndim - 1)
    shape = x.shape[1:axis + 1] + (x.shape[0] * x.shape[axis + 1],) + x.shape[axis + 2:]
    dimensions = [*range(1, axis + 1), 0, *range(axis + 1, x.ndim)]
    return lax.reshape(x, shape, dimensions)
Quickly tested with:
import numpy as np

x = jnp.arange(2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2)
for axis in range(1 - x.ndim, x.ndim - 1):
    c1 = jnp.concatenate(x, axis=axis)
    c2 = _concatenate(x, axis=axis)
    np.testing.assert_array_equal(c1, c2)
The implementation is simple enough that I think we should add it to JAX, along with similar approaches for hstack, vstack, and perhaps other related functions.
Sure :) I'm new here, so if you could help with the GitHub side of things I'd be thankful.
Great! We have a bit of contribution information here: https://jax.readthedocs.io/en/latest/contributing.html#contributing-code-using-pull-requests
Feel free to open a work-in-progress PR if that would be helpful, and let me know if you have any questions
Hi @RadostW - just checking in. Is this still something you'd like to work on? If not, I can plan to put together the fix.
Either way, please let me know - thanks!
@jakevdp I've had a busy week, sorry. It seems it worked out in the end, thanks :)
hstack is very inefficient for tensors, as it produces jaxpr code with length proportional to the size of the traced array.
Compare:
to a better, equivalent code that can be achieved using jnp.reshape
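The comparison snippet is not reproduced above; as a hedged sketch of the kind of reshape-based equivalent being described, for a 3-D input:

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(12).reshape(3, 2, 2)

# hstack of a single 3-D array concatenates its leading-axis slices
# along axis 1; the same result comes from one transpose + reshape,
# which traces to a fixed number of ops regardless of x.shape[0].
equivalent = jnp.transpose(x, (1, 0, 2)).reshape(2, 6)

np.testing.assert_array_equal(jnp.hstack(x), equivalent)
```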
Probably hstack can be re-expressed in terms of reshape in general. I'm new to jax, so maybe there are some negative side effects to such an approach?
Code to reproduce issue: