ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
17.37k stars 1.01k forks

Fix inconsistent example in "Automatic Vectorization" docs #1556

Closed. chrisoffner3d closed this 2 weeks ago

chrisoffner3d commented 2 weeks ago

Proposed changes

The axes over which the following examples in the documentation iterate strike me as strange:

xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))

def naive_add(xs, ys):
    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]

Instead you can use vmap to automatically vectorize the addition:

# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

The issue is as follows:

Maybe naive_add should also wrap the list comprehension in an mx.array call so that it is functionally equivalent to vmap_add, but I have not made that change. I also fixed a minor typo elsewhere in the docs.
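
For readers who want to check the axis mismatch themselves, here is a rough sketch. It is not part of the PR. It uses NumPy as a stand-in for mlx.core (an assumption, so that it runs without MLX installed) and small made-up shapes in place of (4096, 100) and (100, 4096). naive_add_docs copies the loop quoted above. naive_add_matching is a hypothetical loop written to follow in_axes=(1, 0). The expression xs.T + ys stands in for what vmap_add would return with the default output axis.

```python
import numpy as np  # stand-in for mlx.core so the sketch runs anywhere

# Small hypothetical shapes; the docs example uses (4096, 100) and (100, 4096).
rng = np.random.default_rng(0)
xs = rng.uniform(size=(8, 3))
ys = rng.uniform(size=(3, 8))

def naive_add_docs(xs, ys):
    # Loop as quoted from the docs: row i of xs plus column i of ys,
    # for i in range(xs.shape[1]). Each result has length xs.shape[1].
    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]

def naive_add_matching(xs, ys):
    # Hypothetical loop that follows in_axes=(1, 0):
    # column i of xs plus row i of ys. Each result has length xs.shape[0].
    return [xs[:, i] + ys[i] for i in range(xs.shape[1])]

# Stack the lists into arrays so the shapes can be compared directly.
docs_out = np.stack(naive_add_docs(xs, ys))
matching_out = np.stack(naive_add_matching(xs, ys))

# Mapping over axis 1 of xs and axis 0 of ys is the same as xs.T + ys.
vmap_like = xs.T + ys

print(docs_out.shape, matching_out.shape, vmap_like.shape)
print(np.allclose(matching_out, vmap_like))
```

With these shapes the quoted loop produces a (3, 3) result, while the vmap-style computation produces (3, 8), so the two cannot be equivalent as written. Changing the indexing in the loop is only one way to reconcile them. Swapping in_axes to (0, 1) and looping over xs.shape[0] would be another.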