ml-explore / mlx-swift

Swift API for MLX
https://ml-explore.github.io/mlx-swift/
MIT License

[Documentation] Advanced indexing #52

Closed: mzbac closed this issue 7 months ago

mzbac commented 8 months ago

I tried to follow the documentation for advanced indexing to port the Python top_p implementation at https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L21-L25 to Swift.

Following the documentation here: https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/indexing, indexing works fine with a one-dimensional array. However, when I apply it to a two-dimensional array, I get a compiler error (Cannot convert value of type 'PartialRangeFrom' to expected argument type 'MLXArray'). Am I doing something wrong? Here is a code snippet that reproduces the error:

let array = MLXRandom.randInt(0 ..< 100, [1, 5])
let sortIndexes = argSort(array, axis: -1)
// error: Cannot convert value of type 'PartialRangeFrom' to expected argument type 'MLXArray'
let sorted = array[0..., sortIndexes]

davidkoski commented 8 months ago

The problem is the types of the arguments here:

let sortIndexes = argSort(array, axis: -1)
let sorted = array[0..., sortIndexes]

This passes a RangeExpression (the 0...) and an MLXArray, and there is no subscript overload that takes both of those, which is the error you are seeing.

From the Python code:

sorted_probs = probs[..., sorted_indices]

That isn't quite the same as what you have written here -- the ... (spread) operator consumes all of the dimensions except the last (mostly :-)). It gets complicated because sorted_indices is itself an array and may be multidimensional, so it follows these rules:

I think it might be more along the lines of:

if there were a variant that took an MLXArray instead of a single Int.

For comparison, here is the code that handles an MLXArray index:

and the Python implementation side of this:

davidkoski commented 8 months ago

So what do we need here? I think we could do any of the following:

The latter is a lot more complicated and doesn't fit as nicely with Swift because of the loosely typed arguments it would accept. I think we can do better than subscript(indexes: Any...) by using a protocol, but the relative lack of types will still be a bit unusual Swift-wise.

So I am leaning toward the former.

davidkoski commented 8 months ago

Another good option is to use the take() function.

Here is the original Python for the case where the logits have shape (1, 1, 12):

>>> sorted_probs = probs[..., sorted_indices]
>>> sorted_probs.shape
(1, 1, 1, 1, 12)

>>> sorted_probs
array([[[[[2.41196e-10, 1.78221e-09, 1.31688e-08, ..., 0.0158369, 0.11702, 0.864665]]]]], dtype=float32)

The shape of the result is a little bit surprising, but that is how numpy advanced indexing works.
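A quick way to see where that shape comes from: with a single integer index array after `...`, the leading axes of `probs` are kept and the last axis is replaced by the entire shape of the index array. The helper below is a hypothetical plain-Swift sketch (not part of MLX) that models only this one case of numpy-style advanced indexing, assuming exactly one index array in the last position.

```swift
/// Hypothetical helper: result shape of `a[..., idx]` for one integer index array.
/// The spread (`...`) keeps every leading axis of `a`; the index array replaces
/// the last axis with its own (possibly multidimensional) shape.
func ellipsisIndexShape(arrayShape: [Int], indexShape: [Int]) -> [Int] {
    precondition(!arrayShape.isEmpty, "cannot index a scalar")
    return Array(arrayShape.dropLast()) + indexShape
}

// probs has shape (1, 1, 12) and sorted_indices has the same shape.
let resultShape = ellipsisIndexShape(arrayShape: [1, 1, 12], indexShape: [1, 1, 12])
print(resultShape)  // [1, 1, 1, 1, 12]
```

The same rule applies to the 2-D case from the original question: an array of shape [1, 5] indexed with argsort indices of shape [1, 5] would come out as [1, 1, 5].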

The take() function can give us something similar (and in fact it is used in the case where we have a single array of indices):

>>> mx.take(probs, sorted_indices).shape
(1, 1, 12)

That is close, but because of the leading spread operator it would actually be:

>>> mx.take(probs, sorted_indices, -1).shape
(1, 1, 1, 1, 12)

This can be written in Swift like this:

take(probs, sortedIndices, axis: -1)

If that works for you I can update the docs to cover this case as well.
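As a reference for the top_p port, the filtering step can be modeled on a single row of probabilities in plain Swift. This is a sketch under stated assumptions, not the mlx-lm implementation: it uses `[Float]` instead of `MLXArray`, assumes the ascending-sort formulation (keep the tokens whose cumulative probability exceeds `1 - topP`), and stops before the final categorical sampling. The hypothetical `topPFilter` performs the 1-D equivalent of `argSort` followed by `take(probs, sortedIndices, axis: -1)`.

```swift
/// Hypothetical plain-Swift model of top-p (nucleus) filtering on one row.
/// Returns the ascending sort order plus the sorted probabilities with
/// everything outside the nucleus set to zero.
func topPFilter(_ probs: [Float], topP: Float) -> (sortedIndices: [Int], filtered: [Float]) {
    // Argsort, ascending (the role argSort(probs, axis: -1) plays in MLX).
    let sortedIndices = probs.indices.sorted { probs[$0] < probs[$1] }
    // Gather along the last axis (the role take(probs, sortedIndices, axis: -1) plays).
    let sortedProbs = sortedIndices.map { probs[$0] }
    // Running (cumulative) sum of the sorted probabilities.
    var running: Float = 0
    let cumulative = sortedProbs.map { p -> Float in
        running += p
        return running
    }
    // Keep only the tail whose cumulative mass exceeds 1 - topP; zero the rest.
    let filtered = zip(cumulative, sortedProbs).map { c, p in c > 1 - topP ? p : 0 }
    return (sortedIndices, filtered)
}

let (order, kept) = topPFilter([0.1, 0.6, 0.3], topP: 0.8)
print(order)  // [0, 2, 1]
print(kept)
```

The sort order is returned alongside the filtered probabilities because it is what maps a position sampled from `filtered` back to the original token id.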

mzbac commented 8 months ago

@davidkoski, thank you for the comprehensive explanation. I did a quick test and take worked well. It would be great if we could add this to the documentation. I may also create a PR in swift-mlx-example for the top_p implementation to demonstrate how it's done in a real example :)

davidkoski commented 8 months ago

Awesome, I will prep the docs!

mzbac commented 8 months ago

Hi @davidkoski, I just found out that we may still need to support advanced indexing using [..., mlxArray]. take works well for reading elements by index, but we need advanced indexing to set items in the array, for example in mlx-lm's repetition penalty here: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L114
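The distinction matters because a repetition penalty is a read-modify-write: gather the logits at the previously generated token ids, rescale them, then scatter them back into the same positions (`logits[:, indices] = ...` in Python). `take` covers only the gather. Below is a plain-Swift sketch on one row of logits; the function name is hypothetical, and it assumes the common rule of dividing positive logits by the penalty and multiplying negative ones.

```swift
/// Hypothetical helper: apply a repetition penalty to one row of logits.
/// Gathers the logits at the given token ids, rescales them, and scatters the
/// results back into the same positions.
func applyRepetitionPenalty(_ logits: [Float], tokens: [Int], penalty: Float) -> [Float] {
    var result = logits
    // Gather: read the current logits for the previously generated tokens.
    let selected = tokens.map { result[$0] }
    // Rescale: both branches make the token less likely when penalty > 1.
    let rescaled = selected.map { $0 < 0 ? $0 * penalty : $0 / penalty }
    // Scatter: write the rescaled values back to the same token positions.
    for (token, value) in zip(tokens, rescaled) {
        result[token] = value
    }
    return result
}

let updated = applyRepetitionPenalty([2.0, -1.0, 0.5, 4.0], tokens: [0, 1], penalty: 2.0)
print(updated)  // [1.0, -2.0, 0.5, 4.0]
```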

davidkoski commented 8 months ago

Ah, interesting. Let me look at what it would take to get full Python-style indexing implemented.

davidkoski commented 8 months ago

See #55 -- not done yet, but you can see what it looks like. Check out the tests to see it in use.

davidkoski commented 7 months ago

#55 is merged -- OK to close?

mzbac commented 7 months ago

Thanks @davidkoski, this is awesome and I really appreciate the effort. I will do some tests today and close the issue.