ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/

[BUG] matmul yields different results when using concat #1082

Open muchi674 opened 1 week ago

muchi674 commented 1 week ago

Describe the bug: matmul yields different results when multiplying vectors concatenated into the same matrix versus multiplying them separately.

To Reproduce: run the following code:

import mlx.core as mx
import numpy as np

W, H = 2, 5

def test0():
    # mlx: the same float16 inputs multiplied two ways
    w = mx.random.uniform(-1, 1, (W, H), dtype=mx.float16)
    x0 = mx.random.uniform(-1, 1, (W,), dtype=mx.float16)
    x1 = mx.random.uniform(-1, 1, (W,), dtype=mx.float16)
    print(mx.array([x0 @ w, x1 @ w]))  # multiply each vector separately, then stack
    print(mx.array([x0, x1]) @ w)      # stack the vectors first, then multiply once

def test1():
    # numpy reference: the same comparison in float16
    np_w = np.random.uniform(-1, 1, (W, H)).astype(np.float16)
    x0 = np.random.uniform(-1, 1, (W,)).astype(np.float16)
    x1 = np.random.uniform(-1, 1, (W,)).astype(np.float16)
    print(np.array([x0 @ np_w, x1 @ np_w]))  # separate matmuls, then stack
    print(np.array([x0, x1]) @ np_w)         # stacked matmul

if __name__ == "__main__":
    print("mlx:")
    mx.random.seed(0)
    test0()

    print("numpy:")
    np.random.seed(0)
    test1()

output:

mlx:
array([[-0.256348, 0.953125, 0.179932, 0.740234, -0.149292],
       [0.234131, -0.737305, -0.0961914, -0.549805, 0.0778198]], dtype=float16)
array([[-0.256348, 0.953125, 0.179932, 0.740234, -0.149292],
       [0.234131, -0.737793, -0.0962524, -0.550293, 0.0778809]], dtype=float16)
numpy:
[[ 0.07385  0.2439   0.1653   0.10596 -0.1026 ]
 [ 0.2615  -0.04764  0.695    0.8013  -0.2192 ]]
[[ 0.07385  0.2439   0.1653   0.10596 -0.1026 ]
 [ 0.2615  -0.04764  0.695    0.8013  -0.2192 ]]

Expected behavior: the last four numbers of the mlx output (second row) should match between the two versions.


Additional context: if this is the case, there are likely plenty of similar discrepancies in your mlx_lm library.

awni commented 1 week ago

I don't think this is a bug; it's due to numerical differences. Order of operations with finite-precision floating point is not associative, and the two versions you have could accumulate in different orders. The lower precision (float16) exacerbates the effect.
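
For example (plain numpy float16 scalars, nothing mlx-specific), summing the same three values under two groupings rounds differently after each step:

import numpy as np

# Illustrative only: each float16 addition rounds its result, so
# different groupings of the same operands can disagree in the low bits.
a, b, c = np.float16(0.1), np.float16(0.2), np.float16(0.3)

left = (a + b) + c   # accumulate left to right
right = a + (b + c)  # accumulate right to left

print(left, right, left == right)  # e.g. 0.5996 0.6001 False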

If you need them to match (or at least be a lot closer), use fp32. I tried it and they were identical in that case.
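
A minimal sketch of that check, reusing the shapes from the repro (mx.allclose here is just for the comparison):

import mlx.core as mx

# Sketch: the same two orderings as the repro above, but at float32.
W, H = 2, 5
mx.random.seed(0)
w = mx.random.uniform(-1, 1, (W, H), dtype=mx.float32)
x0 = mx.random.uniform(-1, 1, (W,), dtype=mx.float32)
x1 = mx.random.uniform(-1, 1, (W,), dtype=mx.float32)

separate = mx.array([x0 @ w, x1 @ w])  # multiply each vector, then stack
batched = mx.array([x0, x1]) @ w       # stack first, then one matmul
print(mx.allclose(separate, batched))  # reportedly True at float32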

I will let @jagrit06 comment on this before closing, to be sure. Also, if you notice any instances with larger discrepancies, that would be useful to share.