jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape? #6859

Closed RadostW closed 3 years ago

RadostW commented 3 years ago

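When given a single multidimensional array, hstack is very inefficient: it produces a jaxpr whose length grows with the size of the traced array, because every entry along the leading axis becomes its own slice and squeeze.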

Compare:

{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      n = squeeze[ dimensions=(0,) ] m
      o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      p = squeeze[ dimensions=(0,) ] o
      q = concatenate[ dimension=1 ] n p
      r = slice[ limit_indices=(1, 4, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      s = squeeze[ dimensions=(0,) ] r
      t = slice[ limit_indices=(2, 4, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      u = squeeze[ dimensions=(0,) ] t
      v = concatenate[ dimension=1 ] s u
      w = slice[ limit_indices=(1, 4, 3, 3)
                 start_indices=(0, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      x = squeeze[ dimensions=(0,) ] w
      y = slice[ limit_indices=(2, 4, 3, 3)
                 start_indices=(1, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      z = squeeze[ dimensions=(0,) ] y
      ba = slice[ limit_indices=(3, 4, 3, 3)
                  start_indices=(2, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bb = squeeze[ dimensions=(0,) ] ba
      bc = slice[ limit_indices=(4, 4, 3, 3)
                  start_indices=(3, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bd = squeeze[ dimensions=(0,) ] bc
      be = concatenate[ dimension=1 ] x z bb bd
      bf = slice[ limit_indices=(1, 12, 3)
                  start_indices=(0, 0, 0)
                  strides=(1, 1, 1) ] be
      bg = squeeze[ dimensions=(0,) ] bf
      bh = slice[ limit_indices=(2, 12, 3)
                  start_indices=(1, 0, 0)
                  strides=(1, 1, 1) ] be
      bi = squeeze[ dimensions=(0,) ] bh
      bj = slice[ limit_indices=(3, 12, 3)
                  start_indices=(2, 0, 0)
                  strides=(1, 1, 1) ] be
      bk = squeeze[ dimensions=(0,) ] bj
      bl = slice[ limit_indices=(4, 12, 3)
                  start_indices=(3, 0, 0)
                  strides=(1, 1, 1) ] be
      bm = squeeze[ dimensions=(0,) ] bl
      bn = concatenate[ dimension=1 ] bg bi bk bm
  in (bn,) }

with the much shorter, equivalent jaxpr obtained with jax.lax.reshape (using its dimensions argument):

{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
                   new_sizes=(12, 12) ] l
  in (m,) }

Perhaps hstack can be re-expressed in terms of reshape in general. I'm new to JAX, so maybe there are some negative side effects to such an approach?


Code to reproduce the issue:

import jax
import jax.numpy as jnp

n = 2

mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))

def stack_hard(AA,AB,BB):
    return jnp.hstack(
        jnp.hstack(
            jnp.hstack(
                jnp.hstack(
                    jnp.array(
                        [[AA,AB],[AB,BB]]
                    )
                )
            )
        )
    )

def stack_easy(AA,AB,BB):
    return  jax.lax.reshape(
                jnp.array([[AA,AB],[AB,BB]]),
                (6*n,6*n),
                dimensions = (0,2,4,1,3,5)
            )

# JIT compilation of stack_hard is very slow for larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA, mAB, mBB)

print('===========================')
print(
    jax.make_jaxpr(stack_hard)(mAA,mAB,mBB)
    )

print('===========================')
print(
    jax.make_jaxpr(stack_easy)(mAA,mAB,mBB)
    )

print(stack_easy(mAA,mAB,mBB))
print(stack_hard(mAA,mAB,mBB))
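
A sketch (not from the thread) of what the permutation (0, 2, 4, 1, 3, 5) actually computes: each n x n grid of 3x3 blocks is flattened into a 3n x 3n matrix, and the four results are tiled with jnp.block. flatten_blocks and stack_block are hypothetical helper names, and the check assumes square 3x3 blocks as in the reproduction above.

```python
import jax
import jax.numpy as jnp
import numpy as np

def flatten_blocks(X):
    # X has shape (n, n, 3, 3): an n x n grid of 3x3 blocks, block (a, b) = X[a, b].
    # Reorder axes (a, b, p, q) -> (a, p, b, q), then merge (a, p) into rows
    # and (b, q) into columns to get the (3n, 3n) matrix those blocks tile.
    rows, cols = X.shape[0] * X.shape[2], X.shape[1] * X.shape[3]
    return jnp.transpose(X, (0, 2, 1, 3)).reshape(rows, cols)

def stack_block(AA, AB, BB):
    # Assemble the symmetric 2 x 2 block matrix from the flattened blocks
    return jnp.block([[flatten_blocks(AA), flatten_blocks(AB)],
                      [flatten_blocks(AB), flatten_blocks(BB)]])

def stack_easy(AA, AB, BB):
    size = 2 * AA.shape[0] * AA.shape[2]  # 6n when the blocks are 3x3
    return jax.lax.reshape(jnp.array([[AA, AB], [AB, BB]]), (size, size),
                           dimensions=(0, 2, 4, 1, 3, 5))

n = 2
mAA = 1.0 * jnp.arange(3 * n * 3 * n).reshape((n, n, 3, 3))
mAB = 2.0 * jnp.arange(3 * n * 3 * n).reshape((n, n, 3, 3))
mBB = 10.0 * jnp.arange(3 * n * 3 * n).reshape((n, n, 3, 3))

# Both constructions should give the same 12 x 12 matrix
np.testing.assert_allclose(stack_easy(mAA, mAB, mBB), stack_block(mAA, mAB, mBB))
```
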
jakevdp commented 3 years ago

Hi - the issue here is that hstack's call signature takes a single argument, which is a tuple of arrays.

A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.

I'm not sure what we could do to "fix" this. Maybe we could raise an error when a single array is passed to hstack, to prevent this sort of silent conversion into a Python tuple of arrays, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make the computational cost implicit in the function's signature more apparent.

What do you think?

jakevdp commented 3 years ago

Also, I don't think it's generally true that hstack of a single array can be expressed in terms of a reshape. Here's a counter-example:

>>> import jax.numpy as jnp
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jnp.hstack(x)
DeviceArray([[ 0,  1,  4,  5,  8,  9],
             [ 2,  3,  6,  7, 10, 11]], dtype=int32)

I'm not sure there's any alternative here other than splitting the array into three pieces and passing them to lax.concatenate, which is what hstack currently does.
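
For reference, a minimal sketch of that split-then-concatenate path written out by hand for the example above. The number of indexing operations grows with x.shape[0], which is what makes the jaxpr long in the nested case; the exact primitives x[i] emits may differ between JAX versions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 2, 2)

# One slice per leading entry: three separate (2, 2) arrays
pieces = [x[i] for i in range(x.shape[0])]

# A single concatenate of the pieces along their axis 1 reproduces jnp.hstack(x)
unrolled = lax.concatenate(pieces, dimension=1)
print(unrolled)
```
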

RadostW commented 3 years ago

Oh, now I understand better why this happens. Perhaps we can improve just the case where a single jnp array is passed as the argument.

I'm pretty sure jnp.hstack can always be expressed with jax.lax.reshape (note: not jnp.reshape), thanks to its handy optional dimensions argument.

For your example it would be:

>>> import jax.numpy as jnp
>>> import jax
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jax.lax.reshape(x,(2,6),dimensions=(1,0,2)) - jnp.hstack(x)
DeviceArray([[0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]], dtype=int32)
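
The dimensions argument of lax.reshape first permutes the operand's axes in the given order and then reshapes in row-major order. A small sketch showing that the call above is the same as an explicit transpose followed by an ordinary reshape:

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

x = jnp.arange(12).reshape(3, 2, 2)

# Permute axes (0, 1, 2) -> (1, 0, 2), then reshape row-major to (2, 6)
via_dims = lax.reshape(x, (2, 6), dimensions=(1, 0, 2))

# The same computation spelled out with two jnp calls
via_transpose = jnp.transpose(x, (1, 0, 2)).reshape(2, 6)

assert (via_dims == via_transpose).all()
```

The permutation moves the "which slice" axis next to the axis being concatenated, so a row-major reshape lays the slices side by side.
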
jakevdp commented 3 years ago

Nice, I didn't know about that! If it's not too complicated, I think special-casing the single-array case in terms of these kinds of reshapes would be worthwhile. Maybe it could all be handled at the level of jnp.concatenate.

adam-hartshorne commented 3 years ago

Is there a way to see which operations in a program might be undergoing such a forced conversion from XLA to NumPy and back to XLA? I am currently seeing a significant slowdown in a program I have written, and I am wondering whether it is due to similar issues going on in the background.

jakevdp commented 3 years ago

The main source of this kind of thing is calling __iter__ on a DeviceArray.

adam-hartshorne commented 3 years ago

How can I find out if this is the case?
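
One rough way to check, sketched here as a hypothetical helper rather than a JAX feature: trace the function with jax.make_jaxpr and tally its top-level primitives. If the count scales with an array dimension, something is probably iterating over a traced array. Equations nested inside sub-jaxprs (for example from jit-decorated helpers) are not expanded by this sketch.

```python
from collections import Counter

import jax
import jax.numpy as jnp
from jax import lax

def primitive_counts(f, *args):
    """Count the top-level primitives in the jaxpr that f traces to (rough diagnostic)."""
    closed = jax.make_jaxpr(f)(*args)
    return Counter(eqn.primitive.name for eqn in closed.jaxpr.eqns)

def unrolled_hstack(x):
    # Python-level iteration over a traced array: one indexing op per leading entry
    return lax.concatenate([x[i] for i in range(x.shape[0])], dimension=1)

def reshape_hstack(x):
    # Same result via one reshape with an axis permutation (assumes x.ndim == 3)
    s = x.shape
    return lax.reshape(x, (s[1], s[0] * s[2]), dimensions=(1, 0, 2))

for rows in (4, 32):
    x = jnp.zeros((rows, 2, 2))
    print(rows,
          sum(primitive_counts(unrolled_hstack, x).values()),
          sum(primitive_counts(reshape_hstack, x).values()))
```

Printing the jaxpr itself and looking for long runs of slice/squeeze equations gives the same signal in a more readable form.
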

shoyer commented 3 years ago

Alternatively, I think it might make sense to just raise an error when a single array is passed into hstack/vstack/stack/concatenate. We would force the user to write something more explicit like jnp.vstack(list(array)), which hopefully has more obvious performance implications.
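
A sketch of what that stricter behavior could look like, written as a hypothetical wrapper (strict_vstack is not a JAX function): a bare array is rejected, and the caller has to unstack it explicitly.

```python
import jax
import jax.numpy as jnp
import numpy as np

def strict_vstack(tup):
    """Hypothetical stricter vstack: refuse a bare array so the unstacking is explicit."""
    if isinstance(tup, (jax.Array, np.ndarray)):
        raise TypeError(
            "strict_vstack expects a sequence of arrays; "
            "pass list(arr) explicitly to split arr along axis 0"
        )
    return jnp.vstack(tup)

x = jnp.arange(6).reshape(3, 2)
try:
    strict_vstack(x)
except TypeError as e:
    print("rejected:", e)
print(strict_vstack(list(x)).shape)
```
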

jakevdp commented 3 years ago

That’s two votes for erroring on implicit tuple conversion - maybe that’s a cleaner route.

shoyer commented 3 years ago

Oh, I see you did suggest that already 👍

RadostW commented 3 years ago

Erroring on implicit tuple conversion 👍

As for hstack and vstack, I'd still prefer a special case using reshape, especially since the workaround is not immediately obvious.

I think I have working code for this (modulo edge cases such as passing a scalar):

import jax.numpy as jnp
import jax
import numpy as np

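def hstack_alternative(array):
    # Assumes array.ndim >= 3: hstack joins the slices array[i] along their axis 1,
    # so swap the first two axes, then merge the original axes 0 and 2.
    shp = array.shape
    return jax.lax.reshape(
        array,
        (shp[1], shp[0] * shp[2]) + shp[3:],
        dimensions=((1, 0) + tuple(range(2, len(shp)))),
    )

def vstack_alternative(array):
    # Assumes array.ndim >= 3: vstack joins the slices along their axis 0,
    # which is a plain row-major merge of the first two axes.
    shp = array.shape
    return jax.lax.reshape(array, (shp[0] * shp[1],) + shp[2:])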

x = jnp.arange(12).reshape(3, 2, 2)
y = jnp.arange(2 * 3 * 5).reshape(2, 3, 5)
z = jnp.arange(2 * 3 * 5 * 7).reshape(2, 3, 5, 7)

def test_hstack_alternative():
    assert np.allclose(jnp.hstack(x), hstack_alternative(x))
    assert np.allclose(jnp.hstack(y), hstack_alternative(y))
    assert np.allclose(jnp.hstack(z), hstack_alternative(z))

def test_vstack_alternative():
    assert np.allclose(jnp.vstack(x), vstack_alternative(x))
    assert np.allclose(jnp.vstack(y), vstack_alternative(y))
    assert np.allclose(jnp.vstack(z), vstack_alternative(z))
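
A sketch of how the 1-D and 2-D edge cases could be covered, following NumPy's documented rules (hstack promotes each slice with atleast_1d and joins along axis 0 for 1-D slices, otherwise axis 1; vstack promotes with atleast_2d and joins along axis 0). hstack_any and vstack_any are hypothetical names, and the check compares against NumPy rather than jnp.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

def hstack_any(a):
    """Hypothetical: hstack of the slices a[0], a[1], ... without unrolling them."""
    a = jnp.asarray(a)
    if a.ndim == 0:
        raise ValueError("need at least a 1-D array")
    if a.ndim <= 2:
        # Slices are 0-D or 1-D, so they are joined end to end: a plain flatten
        return jnp.ravel(a)
    s = a.shape
    # Slices are >= 2-D and are joined along their axis 1 (axis 2 of a)
    return lax.reshape(a, (s[1], s[0] * s[2]) + s[3:],
                       dimensions=(1, 0) + tuple(range(2, a.ndim)))

def vstack_any(a):
    """Hypothetical: vstack of the slices a[0], a[1], ... without unrolling them."""
    a = jnp.asarray(a)
    if a.ndim == 0:
        raise ValueError("need at least a 1-D array")
    if a.ndim == 1:
        # Each scalar slice is promoted to a 1x1 row
        return a.reshape(a.shape[0], 1)
    if a.ndim == 2:
        # Each 1-D slice is promoted to a (1, m) row, so the result is a itself
        return a
    # Slices are >= 2-D and are joined along their axis 0: merge a's first two axes
    return a.reshape((a.shape[0] * a.shape[1],) + a.shape[2:])

for shape in [(4,), (3, 4), (3, 2, 2), (2, 3, 5, 7)]:
    a = np.arange(np.prod(shape)).reshape(shape)
    np.testing.assert_array_equal(hstack_any(a), np.hstack(a))
    np.testing.assert_array_equal(vstack_any(a), np.vstack(a))
```
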
jakevdp commented 3 years ago

That's very cool – it would be worth adding this code-path to vstack and hstack in my opinion.

Are you interested in putting together a PR?

jakevdp commented 3 years ago

I was playing with this a bit - here's an implementation of jnp.concatenate for array inputs in terms of lax.reshape:

import jax.numpy as jnp
from jax import lax
from jax._src.util import canonicalize_axis

def _concatenate(x, axis=0):
  assert isinstance(x, jnp.ndarray)
  if x.ndim == 0:
    raise ValueError("Need at least one array to concatenate.")
  if x.ndim == 1:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  axis = canonicalize_axis(axis, x.ndim - 1)
  shape = x.shape[1:axis + 1] + (x.shape[0] * x.shape[axis + 1],) + x.shape[axis + 2:]
  dimensions = [*range(1, axis + 1), 0, *range(axis + 1, x.ndim)]
  return lax.reshape(x, shape, dimensions)

Quickly tested with:

import numpy as np

x = jnp.arange(2*3*4*5*4*3*2).reshape(2, 3, 4, 5, 4, 3, 2)
for axis in range(1 - x.ndim, x.ndim - 1):
  c1 = jnp.concatenate(x, axis=axis)
  c2 = _concatenate(x, axis=axis)
  np.testing.assert_array_equal(c1, c2)

The implementation is simple enough that I think we should add it to JAX, along with similar approaches for hstack, vstack, and perhaps other related functions.
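
Since jax._src.util is a private module that can move between releases, here is a variant of the same reshape trick that normalizes the axis by hand and uses only public APIs. concatenate_array is a hypothetical name; the check compares against np.concatenate, and axis must be static under jit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def concatenate_array(x, axis=0):
    """Hypothetical: concatenate the slices x[0], x[1], ... along `axis` with one reshape."""
    x = jnp.asarray(x)
    if x.ndim == 0:
        raise ValueError("Need at least one array to concatenate.")
    if x.ndim == 1:
        raise ValueError("Zero-dimensional arrays cannot be concatenated.")
    out_ndim = x.ndim - 1  # rank of each slice, and of the result
    if not (-out_ndim <= axis < out_ndim):
        raise ValueError(f"axis {axis} is out of bounds for slices of rank {out_ndim}")
    axis %= out_ndim  # normalize negative axes without the private canonicalize_axis
    s = x.shape
    # Move the stacking axis 0 right before the slices' `axis` (axis + 1 in x), then merge the two
    new_shape = s[1:axis + 1] + (s[0] * s[axis + 1],) + s[axis + 2:]
    dims = (*range(1, axis + 1), 0, *range(axis + 1, x.ndim))
    return lax.reshape(x, new_shape, dimensions=dims)

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
for axis in range(1 - x.ndim, x.ndim - 1):
    np.testing.assert_array_equal(concatenate_array(x, axis), np.concatenate(x, axis=axis))

# Works under jit as long as axis is passed as a static argument
f = jax.jit(concatenate_array, static_argnums=1)
np.testing.assert_array_equal(f(x, 1), np.concatenate(x, axis=1))
```
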

RadostW commented 3 years ago

> Are you interested in putting together a PR?

Sure :) I'm new here, so I'd be grateful for some help with the GitHub side of things.

jakevdp commented 3 years ago

Great! We have a bit of contribution information here: https://jax.readthedocs.io/en/latest/contributing.html#contributing-code-using-pull-requests

Feel free to open a work-in-progress PR if that would be helpful, and let me know if you have any questions.

jakevdp commented 3 years ago

Hi @RadostW - just checking in. Is this still something you'd like to work on? If not, I can plan to put together the fix.

Either way, please let me know - thanks!

RadostW commented 3 years ago

@jakevdp I've had a busy week, sorry. It seems it worked out in the end, thanks :)