ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
14.83k stars, 845 forks

[BUG] Bad result for GPU matmul for specific shape #1084

Closed. awni closed this issue 1 week ago

awni commented 1 week ago

Run the following matmul on the CPU and on the GPU and compare the results. The GPU output contains 0 and inf values, while the CPU output looks more or less correct.

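import mlx.core as mx

# Run once per device. The GPU is the default; uncomment to use the CPU.
# mx.set_default_device(mx.cpu)

s1 = (4, 6479, 2048)
s2 = (256000, 2048)
a = mx.random.uniform(shape=s1)
b = mx.random.uniform(shape=s2)

mx.eval(a, b)

c = a @ b.T
mx.eval(c)  # MLX is lazy, so force the matmul to run

awni commented 1 week ago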

This is causing NaNs in LoRA training with Gemma: https://github.com/ml-explore/mlx-examples/issues/620
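
For a sense of scale, the output size of the reproduction can be worked out from the two shapes alone. The sketch below is plain Python arithmetic and does not call MLX. The comparison against 32-bit integer limits is an assumption about what might make this shape special; the thread does not state the cause of the bad result.

```python
import math

# Shapes copied from the reproduction in this issue.
s1 = (4, 6479, 2048)
s2 = (256000, 2048)

# a @ b.T contracts the shared 2048 axis, leaving (4, 6479, 256000).
out_shape = s1[:-1] + (s2[0],)
n_out = math.prod(out_shape)

# Assumption for illustration only: 32-bit index limits may be relevant.
INT32_MAX = 2**31 - 1
UINT32_MAX = 2**32 - 1

print("output shape:        ", out_shape)
print("output elements:     ", n_out)
print("exceeds int32 range: ", n_out > INT32_MAX)
print("exceeds uint32 range:", n_out > UINT32_MAX)
```

The output holds 6,634,496,000 elements, which is more than either 32-bit limit. In float32 that is also more than 24 GiB for the result alone, so the reproduction needs a machine with enough unified memory.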