ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] mlx crashes with msg - uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64 #1076

Closed lkarthee closed 5 days ago

lkarthee commented 2 weeks ago

Describe the bug Calling `keras.ops.scatter` on the mlx backend with int64 values crashes the Python process with an uncaught C++ exception (`[Scatter::eval_gpu] Does not support int64`) instead of raising a catchable error.

To Reproduce

Include code snippet

import numpy as np
import mlx.core as mx
from keras.src.ops import core
indices = np.array([[1], [3], [4], [7]])
values = np.array([9, 10, 11, 12])
from keras.src import backend
backend.backend()
# >>> 'mlx'
x = core.scatter(indices, values, (8,))
x
# libc++abi: terminating due to uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64
# zsh: abort      python
# keras.ops.scatter for mlx backend
def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = mx.zeros(shape, dtype=values.dtype)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    zeros = zeros.at[indices].add(values)

    return zeros
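For reference, the same scatter-add semantics can be sketched in plain NumPy (no mlx needed), using `np.add.at` for unbuffered accumulation; `scatter_np` is a hypothetical helper name, not part of Keras or mlx:

```python
import numpy as np

def scatter_np(indices, values, shape):
    # NumPy sketch of the scatter above: accumulate `values` into a
    # zero array of `shape` at the positions given by `indices`.
    out = np.zeros(shape, dtype=values.dtype)
    idx = tuple(indices[..., i] for i in range(indices.shape[-1]))
    np.add.at(out, idx, values)  # unbuffered add handles repeated indices correctly
    return out

scatter_np(np.array([[1], [3], [4], [7]]),
           np.array([9, 10, 11, 12]),
           (8,))
# array([ 0,  9,  0, 10, 11,  0,  0, 12])
```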

Expected behavior MLX should not abort the process - it should raise a catchable exception or error.


awni commented 2 weeks ago

If you are just asking for a catchable exception then #1077 should close this. We would like to eventually allow int64 and other 8 byte types to work with scatter, but that is more involved.

lkarthee commented 2 weeks ago

Thank you, Awni. Some observations:

zeros = mx.zeros(shape, dtype=values.dtype) 
zeros = zeros.at[indices].add(values) 

I tried the following, but it does not work because `add` does not take a `device` kwarg:

if zeros.dtype in [mx.int64, mx.uint64] and mx.default_device().type == mx.DeviceType.gpu:
  device = mx.Device(type=mx.DeviceType.cpu)
  zeros = zeros.at[indices].add(values, device=device)  # fails: `add` has no `device` kwarg
else:
  zeros = zeros.at[indices].add(values)

It would be helpful if mlx could fall back to the CPU for scatter ops that are not supported on the GPU, or allow a `device` kwarg for all scatter ops.

Additional ops which are impacted by this bug:

awni commented 5 days ago

The crash message is confusing - it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning a workaround in the error message?

I improved the message in #1077. The problem is with the values.

Can scatter ops use the CPU device for int64 and uint64?

We prefer not to silently route to the CPU for ops without a GPU back-end. You can do this in the API by changing the default stream to the CPU before calling the scatter when the dtype is int64/uint64.

Are there any other ops which are not supported on the GPU and run on the CPU?

Just a few. FFT and some of the LAPACK ops (QR / Inverse). Metal support for FFT is coming soon in #981.

I tried this and it does not work as `add` does not take a `device` kwarg:

You can use a context manager. For most free functions, the `stream` kwarg also works. E.g.

v = mx.array([1, 2, 3])
u = mx.array([1, 2])
idx = mx.array([0, 1])

with mx.stream(mx.cpu):
    out = v.at[idx].add(u)

lkarthee commented 5 days ago

Thank you @awni for the fix.