YouJiacheng opened this issue 2 years ago (status: open)
The key piece of code is this function: https://github.com/google/jax/blob/7098088f4eb15cf750398889e4341dbc15cda1b3/jax/interpreters/mlir.py#L580
which searches among the donated buffers for one to reuse as an output. Currently it requires an exact aval
(type/shape) match, since that is safe on all platforms.
You are correct to observe that on GPU it would also be correct to use a more relaxed condition: namely, that the element type matches and the product of the dimension sizes match, rather than requiring exact equality of the dimensions.
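The relaxed GPU condition can be sketched as a small predicate. This is illustrative only, not the actual JAX code: `Aval` here is a minimal stand-in for JAX's abstract value (shape plus dtype), and `donation_compatible_on_gpu` is a hypothetical helper name.

```python
from collections import namedtuple
from math import prod

# Minimal stand-in for an abstract value (aval): shape + element type.
Aval = namedtuple("Aval", ["shape", "dtype"])

def donation_compatible_on_gpu(donated, output):
    """Hypothetical check for the relaxed GPU condition: the donated
    buffer can back the output if element types match and the total
    element counts are equal, even when the dimensions differ
    (e.g. a (2, 6) buffer backing a (3, 4) output)."""
    return (donated.dtype == output.dtype
            and prod(donated.shape) == prod(output.shape))

print(donation_compatible_on_gpu(Aval((2, 6), "float32"),
                                 Aval((3, 4), "float32")))  # True: 12 == 12
print(donation_compatible_on_gpu(Aval((2, 6), "float32"),
                                 Aval((3, 5), "float32")))  # False: 12 != 15
```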
This more relaxed condition would not be correct in general on TPU. TPUs use a tiled memory layout, and some dimensions may be padded. Exactly which depends on a layout heuristic. So the mapping from the shape of the buffer to the number of bytes that it consumes is more complicated.
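To make the TPU caveat concrete, here is a rough sketch of how tile padding decouples element count from byte count. The (8, 128) tile used below is a commonly cited TPU layout for 32-bit types, but the real mapping depends on compiler heuristics, so treat the numbers as illustrative only.

```python
from math import ceil

def padded_bytes_tpu(shape, itemsize=4, tile=(8, 128)):
    """Illustrative only: approximate byte footprint of a 2D buffer
    whose trailing dimensions are padded up to a tile. Real TPU
    layouts are chosen by compiler heuristics and may differ."""
    rows, cols = shape
    padded_rows = ceil(rows / tile[0]) * tile[0]
    padded_cols = ceil(cols / tile[1]) * tile[1]
    return padded_rows * padded_cols * itemsize

# Both shapes hold 2048 elements, yet their padded footprints differ,
# so a same-element-count donation would not be safe here:
print(padded_bytes_tpu((16, 128)))  # 16 * 128 * 4  = 8192 bytes
print(padded_bytes_tpu((128, 16)))  # 128 * 128 * 4 = 65536 bytes
```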
If you wanted to send a PR that relaxed the condition for GPU only, that'd work!
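A PR along these lines might branch the donated-buffer search on platform. The sketch below is a simplified, self-contained simulation of that search, not the actual mlir.py code; `find_donation` and the `platform` argument are hypothetical names.

```python
from collections import namedtuple
from math import prod

# Minimal stand-in for an abstract value (aval): shape + element type.
Aval = namedtuple("Aval", ["shape", "dtype"])

def find_donation(donated, output, platform):
    """Return the index of a donated aval that can back `output`, or
    None. Exact shape match is accepted everywhere; on GPU, a
    same-dtype buffer with the same element count (a reshape) is
    also accepted."""
    for i, d in enumerate(donated):
        if d is None:
            continue  # already consumed by an earlier output
        exact = d.shape == output.shape and d.dtype == output.dtype
        relaxed = (platform == "gpu"
                   and d.dtype == output.dtype
                   and prod(d.shape) == prod(output.shape))
        if exact or relaxed:
            donated[i] = None  # mark the buffer as consumed
            return i
    return None

donated = [Aval((2, 6), "float32")]
# The exact-shape search finds nothing for a (3, 4) output...
print(find_donation(list(donated), Aval((3, 4), "float32"), "tpu"))  # None
# ...but the relaxed GPU condition reuses the 12-element buffer.
print(find_donation(list(donated), Aval((3, 4), "float32"), "gpu"))  # 0
```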
Thanks for the reply! I will take a look at this.
@hawkinsp Hello, I would like to ask whether reshaped buffer donation is possible on CPU now that https://github.com/google/jax/issues/1733 has been resolved (I guess this should not be a problem if the restriction only stems from the tiled memory layout, but I am asking just for clarification). Thank you in advance!
I expect this function could be optimized to something like `.view` in PyTorch. But I got 11013.
@jakevdp