ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
14.83k stars, 845 forks

[FEATURE REQUEST] mx.grad doesn't alias argnums and argnames #1072

Closed c0g closed 1 week ago

c0g commented 2 weeks ago

Not a bug, more of a nice-to-have-or-at-least-better-error-message-please kind of thing.

If you differentiate a function, you can't pass the differentiated argument as a kwarg, even if the function is differentiated w.r.t. that kwarg's argnum:

import mlx.core as mx
import mlx.nn as mn
a = mx.array([1., 2, 3, 4])
xent_grad = mx.grad(mn.losses.cross_entropy)
xent_grad(logits=a, targets=mx.array(0))
# ^- doesn't work :-(((
# ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 arguments.

Calling xent_grad(a, targets=mx.array(0)), or defining xent_grad = mx.grad(mn.losses.cross_entropy, argnames='logits'), both work fine of course.

An ergonomic solution would be to let me do either -- e.g.

xent_grad = mx.grad(mn.losses.cross_entropy, argnums=0, argnames='logits')
xent_grad(logits=a, targets=mx.array(0))

This doesn't work atm, even though argnums=0 and argnames='logits' refer to the same parameter.
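To make the equivalence concrete, here is a minimal sketch of how a parameter name could be resolved to its positional index with Python's `inspect.signature`. The `cross_entropy` stub only mirrors the parameter order quoted in this thread, and `argname_to_argnum` is a hypothetical helper for illustration, not part of MLX.

```python
import inspect


def cross_entropy(logits, targets, weights=None, axis=-1,
                  label_smoothing=0.0, reduction="none"):
    """Stub: only the parameter order matters for this sketch."""
    raise NotImplementedError


def argname_to_argnum(fn, name):
    """Return the positional index of parameter `name` (hypothetical helper)."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    for index, param in enumerate(inspect.signature(fn).parameters.values()):
        if param.name == name:
            if param.kind not in positional_kinds:
                raise ValueError(f"{name!r} cannot be passed positionally")
            return index
    raise ValueError(f"{fn.__name__} has no parameter named {name!r}")


print(argname_to_argnum(cross_entropy, "logits"))   # 0
print(argname_to_argnum(cross_entropy, "targets"))  # 1
```

With a lookup like this, argnums=0 and argnames='logits' could be recognized as the same parameter regardless of how the caller passes it.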

awni commented 2 weeks ago

I think a good first step is to improve the error message. Even just adding the word positional would probably help a bit, as in:

# ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.

The right thing to do for args / kwargs is somewhat ambiguous. I can see an argument for being more relaxed. I can also see an argument for matching the argnums/argnames to the function parameters as we do now.
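As a rough illustration of the "more relaxed" option, the sketch below binds the caller's arguments to the function signature and forwards every leading argument positionally, so an index-based transform sees it. Everything here is an assumption for illustration: `strict_positional` is a toy stand-in that only mimics the positional check (it returns the selected argument rather than a gradient), and `relaxed` is a hypothetical wrapper, not an MLX API.

```python
import inspect
from functools import wraps


def strict_positional(fn, argnums=0):
    """Toy stand-in for a transform that indexes positional args only.

    Not MLX code: it returns the selected argument instead of a gradient.
    """
    def wrapped(*args, **kwargs):
        if argnums >= len(args):
            raise ValueError(
                f"Can't use argument index {argnums} because the function "
                f"is called with only {len(args)} positional arguments."
            )
        return args[argnums]
    return wrapped


def relaxed(transform, fn, **transform_kwargs):
    """Bind keyword arguments to fn's signature before calling the transform."""
    sig = inspect.signature(fn)
    transformed = transform(fn, **transform_kwargs)

    @wraps(fn)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # bound.args holds the leading arguments in parameter order,
        # including those the caller passed by keyword.
        return transformed(*bound.args, **bound.kwargs)
    return wrapped


def loss(logits, targets):
    return (logits - targets) ** 2


strict = strict_positional(loss)
flexible = relaxed(strict_positional, loss)
print(flexible(logits=3.0, targets=1.0))  # 3.0
```

Calling `strict(logits=3.0, targets=1.0)` raises the ValueError, while `flexible` accepts either calling style. Passing `mx.grad` as the `transform` would presumably behave the same way, but that is untested here.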

IamShubhamGupto commented 2 weeks ago

@awni I'd like to start working on core mlx issues, and I think this one is a good place to start.

> I think a good first step is to improve the error message. Even just adding the word positional would probably help a bit, as in:
>
> # ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.

This works for the above message but what would be a more general way to improve error messages across the package? Should we try matching torch's error messages? Those are tried and tested.

awni commented 2 weeks ago

> This works for the above message but what would be a more general way to improve error messages across the package?

For now, let's take them on a one-by-one basis. If you see something that you think should be improved, feel free to file an issue and we can discuss. For this issue, I would just look at the grad error message.

IamShubhamGupto commented 1 week ago

> This works for the above message but what would be a more general way to improve error messages across the package?
>
> For now, let's take them on a one-by-one basis. If you see something that you think should be improved, feel free to file an issue and we can discuss. For this issue, I would just look at the grad error message.

Alright then, let's do that! Seeing there's no existing PR for this, I'd like to be assigned. I'll make a PR shortly.

IamShubhamGupto commented 1 week ago

UPDATE: the error message now reads the way you'd expect:

(mlx-dev) shubham@Shubhams-MBP mlx % python dummy.py                                   
Traceback (most recent call last):
  File "/Users/shubham/Documents/workspace/forks/mlx/dummy.py", line 5, in <module>
    xent_grad(logits=a, targets=mx.array(0))
ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.

If we give all the arguments a default value and then check whether the user has provided something other than the default, that should allow accepting positional arguments, right?

def cross_entropy(
    logits: Optional[mx.array] = None,
    targets: Optional[mx.array] = None,
    weights: Optional[mx.array] = None,
    axis: Optional[int] = -1,
    label_smoothing: Optional[float] = 0.0,
    reduction: Optional[Reduction] = "none",
) -> mx.array:
    if logits is None or targets is None:
        raise ValueError("Both logits and targets must be provided and cannot be None.")