ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
15.01k stars 856 forks

[BUG] Passing `axis=None` into `argpartition` causes `TypeError` #1121

Closed hegyibalint closed 2 weeks ago

hegyibalint commented 2 weeks ago

Describe the bug

The documentation states that the `axis` parameter of `argpartition` can be `None`, in which case the function partitions the flattened array.

Instead, when `None` is passed, a `TypeError` is raised:

TypeError                                 Traceback (most recent call last)
Cell In[7], line 2
      1 a = mx.ones((3,3))
----> 2 mx.argpartition(a, 2, None)

TypeError: argpartition(): incompatible function arguments. The following argument types are supported:
    1. argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array

Invoked with types: mlx.core.array, int, NoneType

I believe the error message also agrees with the documentation, since the signature it prints lists `axis` as `Union[None, int]`.

To Reproduce

import mlx.core as mx

a = mx.ones((3,3))
mx.argpartition(a, 2, None)

Expected behavior

Either `None` should not be accepted and the documentation updated, or the code should be fixed so that `None` is accepted.
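To make the second option concrete, here is a minimal pure-Python sketch of what partitioning "as a flattened array" would be expected to return, assuming the documented `axis=None` behavior follows the usual NumPy-style argpartition contract. The helpers `flatten`, `argpartition_flat`, and `is_valid_argpartition` are hypothetical names for illustration only and are not part of MLX; the sketch uses a full sort, which is one valid answer but not necessarily what MLX computes.

```python
def flatten(rows):
    """Flatten a 2-D nested list in row-major order."""
    return [x for row in rows for x in row]


def argpartition_flat(rows, kth):
    """Reference argpartition over the flattened input (illustrative only).

    kth is unused here: a full argsort satisfies the partition
    property at every position, so it is valid for any kth.
    """
    flat = flatten(rows)
    return sorted(range(len(flat)), key=flat.__getitem__)


def is_valid_argpartition(flat, idx, kth):
    """Check that idx is a permutation that partitions flat at position kth."""
    pivot = flat[idx[kth]]
    return (
        sorted(idx) == list(range(len(flat)))
        and all(flat[i] <= pivot for i in idx[:kth])
        and all(flat[i] >= pivot for i in idx[kth:])
    )


a = [[5.0, 1.0, 4.0], [3.0, 9.0, 2.0], [8.0, 7.0, 6.0]]
idx = argpartition_flat(a, 2)
print(len(idx))                                   # 9 indices into the flattened array
print(is_valid_argpartition(flatten(a), idx, 2))  # True
```

The key point is that the result has one index per element of the flattened input (9 here), not the original 3x3 shape. A checker like `is_valid_argpartition` is more robust than comparing against a fixed index list, because any ordering that satisfies the partition property is an acceptable result.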

Additional context

I've looked at the lambda in the relevant binding code, and it makes sense to me[^1]; the way it tries to handle a `None` argument looks straightforward. I'm at a loss trying to figure out a fix.

[^1]: that being said, I'm not very proficient with binding Python and C++

PRESIDENT810 commented 2 weeks ago

I think it's caused by a small typo: `"axis"_a = -1` should be `"axis"_a.none() = -1` to allow a `None` value.

Let me open a PR and fix it.