The axes over which the following examples in the documentation iterate strike me as strange:
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use vmap to automatically vectorize the addition:
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The issue is as follows:
xs has 4096 rows and 100 columns.
ys has 100 rows and 4096 columns.
naive_add iterates over range(xs.shape[1]), i.e. range(100).
Therefore, naive_add only adds the first 100 rows of xs to the first 100 columns of ys and ignores the remaining 3996 rows and columns, respectively.
The result is a list of 100 arrays with 100 entries each, i.e. equivalent to a (100, 100) array.
I think the intention was to add each row of xs to its corresponding column in ys, producing (the list equivalent of) a (4096, 100) array.
That's why naive_add should iterate over range(xs.shape[0]) instead.
Also, the existing definition of vmap_add, with in_axes=(1, 0), is not functionally equivalent to naive_add, since vmap_add(xs, ys).shape == (100, 4096).
Therefore, vmap_add should use in_axes=(0, 1) instead.
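The shape arithmetic behind the points above can be checked without running MLX. The following is only a bookkeeping sketch in plain Python: slice_shape and vmapped_add_shape are hypothetical helpers written for this illustration (they are not part of MLX), and they assume an elementwise add with no broadcasting and the mapped axis placed first in the output.

```python
def slice_shape(shape, axis):
    """Return (shape of one slice, number of slices) when mapping over `axis`."""
    return shape[:axis] + shape[axis + 1:], shape[axis]


def vmapped_add_shape(x_shape, y_shape, in_axes):
    """Output shape of a vectorized elementwise add, mapped axis first.

    Hypothetical model of the bookkeeping only; it does not call MLX.
    """
    x_slice, x_count = slice_shape(x_shape, in_axes[0])
    y_slice, y_count = slice_shape(y_shape, in_axes[1])
    if x_count != y_count or x_slice != y_slice:
        raise ValueError("mapped sizes or slice shapes do not match")
    return (x_count,) + x_slice


xs_shape, ys_shape = (4096, 100), (100, 4096)

# Number of iterations of the list comprehension in naive_add.
current_loop = len(range(xs_shape[1]))   # range(xs.shape[1])
proposed_loop = len(range(xs_shape[0]))  # range(xs.shape[0])

# Output shapes for the two choices of in_axes.
current_vmap = vmapped_add_shape(xs_shape, ys_shape, (1, 0))
proposed_vmap = vmapped_add_shape(xs_shape, ys_shape, (0, 1))

print(current_loop, proposed_loop)   # iterations before and after the change
print(current_vmap, proposed_vmap)   # vmap output shape before and after
```

Under these assumptions, only the proposed pair (range(xs.shape[0]) together with in_axes=(0, 1)) gives 4096 results of 100 entries each on both sides.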
Maybe naive_add should also wrap the list comprehension in an mx.array call in order to be functionally equivalent to vmap_add, but I have not done this. I also corrected a minor typo in the docs elsewhere.
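To check values rather than only shapes, here is a small stand-in that uses nested Python lists instead of mx.array and shrinks the sizes to 6 and 3 (both choices are assumptions made for this illustration). The column helper and the extra n parameter on naive_add are hypothetical and exist only so that both loop ranges can be compared side by side.

```python
ROWS, COLS = 6, 3  # small stand-ins for 4096 and 100

# xs has shape (ROWS, COLS); ys has shape (COLS, ROWS).
xs = [[10 * r + c for c in range(COLS)] for r in range(ROWS)]
ys = [[100 * r + c for c in range(ROWS)] for r in range(COLS)]


def column(matrix, i):
    """Column i of a nested-list matrix, the analogue of ys[:, i]."""
    return [row[i] for row in matrix]


def naive_add(xs, ys, n):
    """Add row i of xs to column i of ys for i in range(n)."""
    return [[a + b for a, b in zip(xs[i], column(ys, i))] for i in range(n)]


current = naive_add(xs, ys, len(xs[0]))  # mirrors range(xs.shape[1])
proposed = naive_add(xs, ys, len(xs))    # mirrors range(xs.shape[0])

# Reference result: xs plus the transpose of ys, shape (ROWS, COLS).
reference = [[xs[r][c] + ys[c][r] for c in range(COLS)] for r in range(ROWS)]

print(len(current), len(proposed))
print(proposed == reference)
```

In this stand-in, the current loop range produces only the first COLS rows of the reference, while the proposed range reproduces all of it. With real MLX arrays the list returned by naive_add would presumably still need to be combined into a single array (for example with mx.stack) before it could be compared directly against the output of vmap_add.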
I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes.