google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Allow reshaped buffer donation #11036

Open YouJiacheng opened 2 years ago

YouJiacheng commented 2 years ago
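from functools import partial

import jax

# Donate argument 0 (the input buffer); the target shape is a static argument.
@partial(jax.jit, donate_argnums=0, static_argnums=1)
def reshape(a, shape):
    return jax.numpy.reshape(a, shape)

d = 100
v = jax.numpy.zeros((d * d, d * d))
v = reshape(v, (d, d, d, d))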

I expected this function to be optimized into something like `.view` in PyTorch, i.e. a reshape that reuses the donated buffer instead of allocating a new one. Instead I got:

UserWarning: Some donated buffers were not usable: ShapedArray(float32[100,100,100,100]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.

The donation may also be unused if there is no output whose shape matches the donation

#11013

@jakevdp

hawkinsp commented 2 years ago

The key piece of code is this function: https://github.com/google/jax/blob/7098088f4eb15cf750398889e4341dbc15cda1b3/jax/interpreters/mlir.py#L580

which looks amongst the donated buffers to find buffers to reuse for outputs. Currently it searches for an exact aval (type/shape) match, since that is safe on all platforms.
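To make the exact-match search concrete, here is a minimal pure-Python sketch. It is not the code in `mlir.py`: `Aval` and `assign_donations` are hypothetical names used only for illustration, and the real function works on JAX's own abstract values.

```python
from typing import NamedTuple


class Aval(NamedTuple):
    """Hypothetical stand-in for an abstract value: shape plus element type."""
    shape: tuple
    dtype: str


def assign_donations(donated, outputs):
    """Pair each output with an unused donated buffer whose aval is identical.

    Returns (aliases, unused): aliases maps output index -> donated index,
    unused lists the donated indices that no output could reuse.
    """
    available = list(enumerate(donated))
    aliases = {}
    for out_idx, out in enumerate(outputs):
        for pos, (don_idx, don) in enumerate(available):
            if don == out:  # exact shape and dtype equality
                aliases[out_idx] = don_idx
                del available[pos]
                break
    unused = [don_idx for don_idx, _ in available]
    return aliases, unused


d = 100
donated = [Aval((d * d, d * d), "float32")]
outputs = [Aval((d, d, d, d), "float32")]
aliases, unused = assign_donations(donated, outputs)
print(aliases, unused)
```

With the shapes from the report, no output has an aval identical to the donated one, so the donated buffer ends up in `unused`, which corresponds to the warning above.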

You are correct to observe that on GPU it would also be correct to use a more relaxed condition: namely, that the element type matches and the product of the dimension sizes matches, rather than requiring exact equality of the dimensions.
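The difference between the two conditions can be sketched in plain Python. The names `Aval`, `exact_match`, and `relaxed_match` are hypothetical and exist only for this illustration; they are not JAX APIs.

```python
import math
from typing import NamedTuple


class Aval(NamedTuple):
    """Hypothetical stand-in for an abstract value: shape plus element type."""
    shape: tuple
    dtype: str


def exact_match(donated: Aval, output: Aval) -> bool:
    # Current behaviour as described above: shape and dtype must be identical.
    return donated == output


def relaxed_match(donated: Aval, output: Aval) -> bool:
    # Relaxed condition: same element type and same total number of elements.
    return (donated.dtype == output.dtype
            and math.prod(donated.shape) == math.prod(output.shape))


d = 100
donated = Aval((d * d, d * d), "float32")
output = Aval((d, d, d, d), "float32")
print(exact_match(donated, output))    # False
print(relaxed_match(donated, output))  # True
```

Under the relaxed predicate the reshape in the report would be eligible for donation, since both shapes hold the same number of float32 elements.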

This more relaxed condition would not be correct in general on TPU. TPUs use a tiled memory layout, and some dimensions may be padded. Exactly which depends on a layout heuristic. So the mapping from the shape of the buffer to the number of bytes that it consumes is more complicated.
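A toy calculation shows why equal element counts do not imply equal buffer sizes once padding is involved. The tiling rule below (round the last two dimensions up to multiples of 8 and 128) is an assumption made purely for illustration; it is not a description of the actual TPU layout heuristic, and `padded_elements` is a hypothetical helper.

```python
import math


def padded_elements(shape, tile=(8, 128)):
    """Element count after padding the last two dims up to tile multiples.

    The tile sizes are an illustrative assumption, not the real TPU layout.
    """
    *leading, rows, cols = shape
    padded_rows = math.ceil(rows / tile[0]) * tile[0]
    padded_cols = math.ceil(cols / tile[1]) * tile[1]
    return math.prod(leading) * padded_rows * padded_cols


d = 100
flat = (d * d, d * d)
reshaped = (d, d, d, d)

print(math.prod(flat) == math.prod(reshaped))  # True: same logical size
print(padded_elements(flat))                   # 101120000
print(padded_elements(reshaped))               # 133120000
```

Both shapes hold 10^8 logical elements, yet under this assumed tiling their padded footprints differ, so a buffer allocated for one could not simply be reused for the other.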

If you wanted to send a PR that relaxed the condition for GPU only, that'd work!

YouJiacheng commented 2 years ago

Thanks for the reply! I will take a look at this.

Dongyeongkim commented 5 days ago

@hawkinsp Hello, I would like to ask whether reshaped buffer donation is possible on CPU now that https://github.com/google/jax/issues/1733 has been resolved. (I assume it is, since the obstacle described above is the tiled memory layout on TPU, but I would like to confirm.) Thank you in advance!