Closed dastrobu closed 4 months ago
This looks super cool! One thing I'm wondering is what is the expected behavior when the mx.array
goes out of scope but I still have a pointer to the memory via, say, a memoryview
(or another numpy array). Is that undefined? From what I can tell the array will still be freed and so will its data, so writing to the memory pointed to by the memoryview is no good?
@awni thanks for the feedback. I'll be happy to contribute if this is a desired feature.
Regarding your question:
Is that undefined? From what I can tell the array will still be freed and so will its data, so writing to the memory pointed to by the memoryview is no good?
I think this is handled by Py_buffer.obj. The Python buffer holds a reference to the exporting object, which will prevent deletion of the array. So it should be well-defined behavior. Let's hope pybind11 is handling it right for us. Since they document this feature extensively, I guess they will.
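This can be sketched with numpy alone: a memoryview's .obj attribute exposes exactly that Py_buffer.obj reference (a minimal illustration, not mlx-specific):

```python
import numpy as np

a = np.arange(4)
m = memoryview(a)  # Py_buffer.obj now holds a reference to the exporter
del a              # drop our own reference to the array...
m[0] = 42          # ...its memory is still alive and writable
assert isinstance(m.obj, np.ndarray)  # the memoryview keeps the array alive
```

The same mechanism is what should keep an mx.array alive while a view into it exists.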
Maybe we could take this further by constructing a numpy array via the memory view and return it from __array__. Then this would be a writable numpy array view holding a reference to the mlx array. In pseudo code:
def __array__(self):
    return np.array(self, copy=False)
This should work for all but bfloat16, where we could still return a converted readonly float32 np array.
Even better: __array__ could be removed altogether. There is only the edge case of auto conversion from bfloat16 to float32. But it might be better to do this explicitly anyway.
Removing __array__ fixes https://github.com/ml-explore/mlx-data/issues/20 as well (tested locally). In essence, I think mlx-data could just use mlx arrays in all places after this change.
@awni I updated this draft and I think it is quite a simplification and generalization compared to __array__. I wasn't aware of this when starting with the buffer protocol.
This looks great to me. I don't see any obvious reasons not to go this path. It seems strictly more useful than __array__ and keeps things pretty simple, which is nice!
Adding @jagrit06 @angeloskath in case they see anything I might be missing here.
@dastrobu I think you should PR this!
Let me check some details and add some test cases, then I'll turn the draft into a PR for review.
PR is ready for review. Looking forward to your feedback.
One thing I'm wondering is what is the expected behavior when the mx.array goes out of scope [...] Is that undefined?
@awni I added two test cases to verify references are handled as expected.
This looks great to me. I am trying to remember why we opted not to implement the buffer protocol. I can't think of any reason. However, we were returning non-writeable arrays so that it doesn't cause confusing behavior.
For instance the example provided (thanks for writing a nice document btw) is not quite right. The reason it returns 0 is because you are making a new array. If instead you were simply returning x @ x as in function f, then you would get the correct gradient. And that is way more confusing than getting 0s.
So to sum up, I like it and I will do a closer review later but I am leaning towards returning non-writeable buffers. What do you guys think @awni and @jagrit06 ?
def g(x):
    x_view = np.array(x, copy=False)
    x_view[:] += 1  # modify memory without telling mx
    return mx.array(x @ x)
That code is a pretty explicit situation (copy=False + taking grads). Dangerous words perhaps, but I can't imagine anyone expecting function transforms to "work" (not even sure what the right thing to expect there is), when you do that.
I'm not advocating for writeable buffers, just that case doesn't have much weight for me.
What exactly do we lose by making buffers read only (the safer and hence default option)?
Sending mx.array into mx.data ops is a nice feature. Could we get that behavior by casting to numpy in mx.data before performing the op (or something like that)?

If instead you were simply returning x @ x as in function f, then you would get the correct gradient. And that is way more confusing than getting 0s.
That code is a pretty explicit situation (copy=False + taking grads). Dangerous words perhaps, but I can't imagine anyone expecting function transforms to "work" (not even sure what the right thing to expect there is), when you do that.
Thank you for examining the example. Upon reflection, it appears to be unrelated to readonly views, as demonstrated by:
def g(x):
    x = np.array(x, copy=True)
    x[:] += 1
    return mx.array(x @ x)
This is doable with the current implementation, and of course it would also not get the right gradients. Nevertheless, unintended conversions may occur when interfacing with certain libraries. It seems advisable to highlight this behavior to developers in any case.
So this example doesn't appear to be an argument either for or against writable buffers.
Sending mx.array into mx.data ops is a nice feature. Could we get that behavior by casting to numpy in mx.data before performing the op (or something like that)?
This was actually what got me looking into this. If buffers are exposed read-only, mlx-data would have to brute-force write to them anyway.
Pipelines are currently implemented such that buffers are modified in place. So if buffers are to stay read-only, this would require some refactoring. I am not sure what performance impact that would have on pipelines operating on large datasets. I really like the concept of buffers and streams, as they are quite powerful. But always remembering when to use an np array and when to use an mx array really causes a headache.
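The trade-off can be illustrated with numpy alone: an in-place pipeline step needs a writable buffer, while a read-only export forces a copy (a sketch, independent of mlx):

```python
import numpy as np

x = np.zeros(3)
v = np.frombuffer(memoryview(x))  # zero-copy view over x's buffer
v += 1                            # in-place transform, as mlx-data pipelines do it
assert x[0] == 1.0                # the exporter's memory was modified

x.setflags(write=False)           # now x exports a read-only buffer
r = np.frombuffer(memoryview(x))
assert not r.flags.writeable      # any write would raise ValueError
```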
Once there is agreement on the direction to go, I can update the docs accordingly.
This was actually what got me looking into this. If buffers are exposed read-only, mlx-data would have to brute-force write to them anyway.
I thought about this a bit more, and I appreciate this use case (and similar use cases), but I don't think they are worth making the buffer writeable for. Putting MLX writeable array buffers into black box APIs seems risky in general.
Shall we go with read-only for now? Is everyone ok with that? @angeloskath @dastrobu ?
Putting MLX writeable array buffers into black box APIs seems risky in general.
Could you elaborate on your specific concerns? Furthermore, it would be valuable to discuss the recommended approach for mlx-data or similar use cases.
Take, for instance, the image_random_h_flip function. It serves as an ideal model for potential extensions, with the capability to support both np and mx arrays through the buffer protocol. However, when dealing with read-only buffers, the question arises how the implementing C extension should look.
I am open to the read-only solution, provided we can establish a clear and viable path for addressing such cases.
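For illustration only, here is a hypothetical Python-level sketch of such an op (the real image_random_h_flip lives in a C extension, and the function name here is made up): with a writable buffer it can work in place, with a read-only one it has to return a copy.

```python
import numpy as np

def h_flip(buf):
    """Hypothetical horizontal flip over any object exporting a buffer."""
    a = np.asarray(buf)       # zero-copy when buf supports the buffer protocol
    flipped = a[:, ::-1].copy()
    if a.flags.writeable:
        a[:] = flipped        # in place: the caller's memory is updated
        return a
    return flipped            # read-only exporter: a new array is the only option

img = np.array([[1, 2], [3, 4]])
h_flip(img)
assert img[0, 0] == 2         # modified in place

img.setflags(write=False)
out = h_flip(img)
assert img[0, 0] == 2 and out[0, 0] == 1  # copy; original untouched
```

In a C extension the analogous check is the readonly field of Py_buffer (or requesting PyBUF_WRITABLE and handling the failure).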
@dastrobu we discussed a bit offline and the conclusion was that we can keep buffers writable (as you have it). We'll see how it goes; if there are some hiccups we can come back and change it.
Did you have anything else to add to this PR? If not I can take a final pass and we can land it?
@dastrobu we discussed a bit offline and the conclusion was that we can keep buffers writable (as you have it). We'll see how it goes; if there are some hiccups we can come back and change it.
Great approach. I'll be the first one to create an issue on a hiccup if I find one 😉
Did you have anything else to add to this PR? If not I can take a final pass and we can land it?
Regarding the PR, I've incorporated all our discussions and considered the comments from @angeloskath in updating the documentation. Please take a moment to review the revised documentation. Beyond that, the PR is good to go from my perspective.
Lastly, a big thank you to all the reviewers for engaging in such a positive and constructive discussion. In my opinion, the success of open source projects is not solely dependent on good code but also on individuals like you who invest time and effort in discussions with community members like myself.
I left a couple of minor comments there, could you take a look?
Thanks, everything should be resolved now.
Proposed changes
Make array conform to the Buffer Protocol.
The method __array__ is replaced by implementing the Python Buffer Protocol.

Summary:

- __array__ is removed.
- np.array(mx.ones(1, dtype=mx.bfloat16)) will fail, but np.array(mx.ones(1, dtype=mx.bfloat16).astype(mx.float32)) will work.
- mlx-data may use mx arrays instead of numpy arrays, as memory views are writable. See mlx-data issue #20.
- test_buffer_protocol_tf.

Closes #320
Writable Memory Views
Enabling writable memory views has both advantages and a drawback.
As highlighted in the summary, mlx-data relies on in-place modifications to buffers. If mlx-data were to use mlx arrays in the future, which is a sensible idea, mlx arrays must expose writable buffers, as demonstrated in this pull request. An alternative would be to refactor mlx-data such that it always returns new arrays on transformations and does not modify memory in place.

The downside is that direct modification of memory through buffers will not be captured by the grad tracer. Thus, this pull request allows developers to unintentionally cause issues, as illustrated in the following test case:
If preventing such issues is desired, the pull request can be modified to return read-only buffers easily, albeit with the drawback of not supporting in-place operations as done in mlx-data. For comparison, TensorFlow has implemented the buffer protocol such that it returns read-only buffers.

From my perspective, having the flexibility to perform in-place operations on buffers outweighs the concern of potential broken gradients. I would like reviewers to pay special attention to this aspect.
Checklist
Put an x in the boxes that apply.

- [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes