Closed: mzbac closed this issue 7 months ago.
The problem is the types of the arguments here:
let sortIndexes = argSort(array, axis: -1)
let sorted = array[0..., sortIndexes]
This is passing a RangeExpression (the 0...) and an MLXArray -- there is no subscript implementation that takes both of those, which is the error you are seeing here.
From the python code:
sorted_probs = probs[..., sorted_indices]
That isn't quite the same as what you have written here -- the ... (spread) operator will consume all the dimensions up to the last, mostly :-). It gets complicated because sorted_indices is an array itself and may be multidimensional, so it follows these rules:
I think it might be more along the lines of:
if there was a variant that took an MLXArray instead of a single Int.
For comparison, here is the code that handles an MLXArray index:
and the python implementation side of this:
So what do we need here? I think we could do either of the following:
- a subscript method that takes an MLXArray and an axis
- an equivalent of mlx_get_item_nd that allows all kinds of python indexing
The latter is a lot more complicated and doesn't fit as nicely with Swift because of the relaxed types it would take. I think we can do better than subscript(indexes: Any...) using a protocol, but the relative lack of types will still be a bit unusual swift-wise.
So I am leaning toward the former.
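A minimal sketch of what the first option could look like. The subscript label `gathering:` is hypothetical (it is not part of mlx-swift and was chosen only to avoid clashing with any existing subscript). It just forwards to the `take(_:_:axis:)` free function, so `a[gathering: idx, axis: -1]` should behave like Python's `a[..., idx]`:

```swift
import MLX

extension MLXArray {
    /// Hypothetical subscript (not an mlx-swift API): gathers `indices`
    /// along a single `axis`, mirroring Python's `a[..., indices]` when axis == -1.
    subscript(gathering indices: MLXArray, axis axis: Int) -> MLXArray {
        // Module-qualified so the free function `take` is used rather than
        // any same-named instance member of MLXArray.
        MLX.take(self, indices, axis: axis)
    }
}

// Usage: pick elements 2 and 0 from a 1-D array.
let values = MLXArray([10, 20, 30] as [Int32])
let picked = values[gathering: MLXArray([2, 0] as [Int32]), axis: 0]
print(picked.asArray(Int32.self))
```

Keeping the axis explicit avoids the loosely typed `subscript(indexes: Any...)` design mentioned above, at the cost of covering only this one indexing pattern.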
Another good option is to use the take() function.
Here is the original python in the case where the logits are shaped like (1, 1, 12):
>>> sorted_probs = probs[..., sorted_indices]
>>> sorted_probs.shape
(1, 1, 1, 1, 12)
>>> sorted_probs
array([[[[[2.41196e-10, 1.78221e-09, 1.31688e-08, ..., 0.0158369, 0.11702, 0.864665]]]]], dtype=float32)
The shape of the result is a little bit surprising, but that is how numpy advanced indexing works.
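One way to see where the extra axes come from: with a single integer-array index placed after the ellipsis, NumPy-style advanced indexing replaces the indexed (last) axis with the entire shape of the index array. Writing $d = (d_0, \dots, d_{n-1})$ for the shape of `probs` and $k = (k_0, \dots, k_{m-1})$ for the shape of `sorted_indices`:

$$
\operatorname{shape}\bigl(\texttt{probs[..., sorted\_indices]}\bigr) = (d_0, \dots, d_{n-2},\; k_0, \dots, k_{m-1})
$$

With $d = k = (1, 1, 12)$ this gives $(1, 1) \,\Vert\, (1, 1, 12) = (1, 1, 1, 1, 12)$, matching the REPL output above.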
The take() function can give us something similar (and in fact it is used in the case where we have a single array of indices):
>>> mx.take(probs, sorted_indices).shape
(1, 1, 12)
is close, but since we have the leading spread operator it would actually be:
>>> mx.take(probs, sorted_indices, -1).shape
(1, 1, 1, 1, 12)
Which can be written in Swift like this:
take(probs, sortedIndices, axis: -1)
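A slightly fuller sketch of that step, assuming a toy `(1, 1, 3)` logits array in place of real model output. It uses only `softmax`, `argSort` and `take`, and should reproduce the same five-dimensional shape as the Python result:

```swift
import MLX

// Toy logits shaped (1, 1, 3), standing in for real model output (assumption).
let logits = MLXArray([1.0, 3.0, 2.0] as [Float], [1, 1, 3])
let probs = softmax(logits, axis: -1)

// Indices that sort the probabilities in ascending order along the last axis.
let sortedIndices = argSort(probs, axis: -1)

// Equivalent of Python's `probs[..., sorted_indices]`: gather along axis -1.
// The index array's shape (1, 1, 3) replaces the last axis -> (1, 1, 1, 1, 3).
let sortedProbs = take(probs, sortedIndices, axis: -1)
print(sortedProbs.shape)
```

Passing a one-dimensional index array instead would keep the result at `(1, 1, 3)`, since `take` substitutes the index array's shape for the gathered axis.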
If that works for you I can update the docs to cover this case as well.
@davidkoski, thank you for the comprehensive explanation. I did a quick test and take worked well. It would be great if we could add this to the documentation. I may also create a PR in swift-mlx-example for the top_p implementation to demonstrate how it's done in a real example :)
Awesome, I will prep the docs!
Hi @davidkoski,
Just found out that we may still need to support advanced indexing using [..., mlxArray]. take works well for selecting elements via indexing, but we may need advanced indexing to set items in the array, for example in mlx-lm's repetition penalty here: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L114
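Until assignment through advanced indexing is available, one possible workaround is to avoid the write entirely: build a boolean mask of the already generated token ids and blend penalized and original logits with `which` (mlx-swift's name for `where`). This is a sketch, not a port of the mlx-lm code. The helper name `applyRepetitionPenalty`, the `[1, vocab]` logits shape, and the rule "divide positive logits, multiply negative ones" are assumptions made for illustration:

```swift
import MLX

/// Hypothetical helper (not an mlx-swift or mlx-lm API): lowers the logits of
/// tokens in `generated` without writing through an advanced index.
/// Assumes `logits` has shape [1, vocab] and `penalty` > 1.
func applyRepetitionPenalty(_ logits: MLXArray, generated: [Int32], penalty: Float) -> MLXArray {
    guard !generated.isEmpty else { return logits }
    let vocab = logits.shape.last!

    // Column of all token ids, shape [vocab, 1], compared against the generated
    // ids, shape [n]; broadcasting gives [vocab, n], reduced to [vocab].
    let ids = MLXArray(Array(0 ..< Int32(vocab)), [vocab, 1])
    let seen = (ids .== MLXArray(generated)).any(axis: 1)

    // Lower every logit: divide positives by `penalty`, multiply negatives by it.
    let penalized = which(logits .< MLXArray(Float(0)), logits * penalty, logits / penalty)

    // Use the penalized value only for seen tokens; [vocab] broadcasts over [1, vocab].
    return which(seen, penalized, logits)
}

let exampleLogits = MLXArray([2, -2, 1] as [Float], [1, 3])
let penalizedLogits = applyRepetitionPenalty(exampleLogits, generated: [0, 1], penalty: 2)
print(penalizedLogits.asArray(Float.self))
```

The mask costs a `[vocab, n]` comparison, which is fine for short histories but is one reason a real scatter-style assignment would be preferable.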
Ah, interesting. Let me look at what it would take to get the whole python indexing implemented.
See #55 -- not done yet, but you can see what it looks like. Check out the tests to see it in use.
Thanks @davidkoski, this is awesome and really appreciate the effort. I will do some tests today and close the issue.
I tried to follow the documentation for advanced indexing to implement top_p as shown in https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L21-L25 in Python.
After following the documentation here: https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/indexing, it seems to work fine with a one-dimensional array. However, when I try to apply it to a two-dimensional array, I get a compiler error (Cannot convert value of type 'PartialRangeFrom' to expected argument type 'MLXArray'). I'm not sure whether I have done something wrong. Here is the code snippet to reproduce the error: