Open shoyer opened 2 years ago
@shoyer, I've been thinking about this quite a lot in the past, and I'm fairly settled on how it should look. Not expecting everyone to like it though.
I've put a relatively simple implementation and some examples here; would be nice to get your thoughts: https://github.com/arogozhnikov/einops/blob/master/einops/experimental/indexing.py
I wanted this to be my first python-array-api based function, but found out indexing isn't really supported by the standard.
Some syntax ideas here: https://github.com/mcabbott/Tullio.jl
From their README:
Tullio is a very flexible einsum macro. It understands many array operations written in index notation -- not just matrix multiplication and permutations, but also convolutions, stencils, scatter/gather, and broadcasting
@tullio M[x,y,c] := N[x+i, y+j,c] * K[i,j] # sum over i,j, and create M
@tullio S[x] = P[x,y] * log(Q[x,y] / R[y]) # sum over y, and write into S
@tullio A[i,j] += B[i,k,l] * C[l,j] * D[k,j] # sum over k,l, and add to values in A
@tullio (*) Z[j] := X[ind[k],j] * exp(-Y[k]) # product over k
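For readers less familiar with index notation, the first @tullio example above (M[x,y,c] := N[x+i, y+j, c] * K[i,j], summing over i and j) can be spelled out in plain NumPy. This is an illustrative sketch of the semantics, not Tullio's implementation:

```python
import numpy as np

# "Valid" 2D convolution written as a sum of shifted slices:
# M[x, y, c] = sum_{i, j} N[x + i, y + j, c] * K[i, j]
N = np.random.rand(6, 6, 3)   # input, shape (X + I - 1, Y + J - 1, C)
K = np.random.rand(3, 3)      # kernel, shape (I, J)

X = N.shape[0] - K.shape[0] + 1
Y = N.shape[1] - K.shape[1] + 1
M = np.zeros((X, Y, N.shape[2]))
for i in range(K.shape[0]):
    for j in range(K.shape[1]):
        # each kernel tap contributes one shifted slice of N
        M += N[i:i + X, j:j + Y, :] * K[i, j]
```

This also shows why such expressions are convolution-shaped: every occurrence of x+i is a shifted view of the input.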
Julia's line-level macros really shine for this kind of stuff
@tullio M[x,y,c] := N[x+i, y+j,c] * K[i,j] # sum over i,j, and create M
Here is the problem with arbitrary expressions in indexers (aside from implementation complexity): they will immediately be used for conv-style operations (which would run slower and with a much larger memory footprint than cudnn) and immediately run into out-of-bounds or negative indices. I don't see a way to meet varied users' expectations for 'reasonable' out-of-bounds processing.
The path with indexing, or maybe indexing + reduction, looks feasible.
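A small NumPy aside on why 'reasonable' out-of-bounds handling is hard to pin down: negative indices silently wrap around, while indices past the end raise, so a shifted indexer like N[x+i] behaves differently at the two boundaries (illustrative sketch):

```python
import numpy as np

x = np.array([10, 20, 30])

# Negative indices wrap to the end -- no error, no zero-padding:
assert x[-1] == 30

# Indices past the end raise instead of wrapping or padding:
oob_raised = False
try:
    x[np.array([3])]
except IndexError:
    oob_raised = True
```

Whether a user wants wrapping, zero-padding, clamping, or an error depends entirely on the application, which is why a library-level default is contentious.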
I realize that the indexing proposal above looks a bit extraterrestrial at first, but only until you couple indexing with the second part of the proposal (how those indices should be computed):
# for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them
norm_bthw = x_bthwc.norm(dim=-1)
# here you explicitly say which axes argmax should be taken on, and the shape of output is readable - 2 x b x t
indices_2bt = argmax(norm_bthw, 'b t h w -> [h, w] b t')
# note that '[h, w] b t' part just migrated from the previous operation
selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt)
AFAIK, multidim argmax / topk + indexing are not solved in numpy or existing frameworks, and the above looks like quite a consistent solution to me
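For comparison, here is a hypothetical plain-NumPy version of the argmax-then-gather workflow above, showing the index bookkeeping (flatten, unravel_index, broadcast index arrays) that the proposed einindex call would hide:

```python
import numpy as np

# Shapes: b, t, h, w, c -- per timeframe, pick the (h, w) token
# with the highest norm across the channel axis.
b, t, h, w, c = 2, 3, 4, 5, 6
x_bthwc = np.random.rand(b, t, h, w, c)

norm_bthw = np.linalg.norm(x_bthwc, axis=-1)

# argmax over the flattened (h, w) axes, then unravel back to 2D indices
flat_idx_bt = norm_bthw.reshape(b, t, h * w).argmax(axis=-1)   # shape (b, t)
h_idx_bt, w_idx_bt = np.unravel_index(flat_idx_bt, (h, w))

# fancy indexing: broadcast arange arrays for the batch-like axes
b_idx = np.arange(b)[:, None]
t_idx = np.arange(t)[None, :]
selected_btc = x_bthwc[b_idx, t_idx, h_idx_bt, w_idx_bt]       # shape (b, t, c)
```

Every line after the norm computation is shape bookkeeping, which is exactly the part the 'b t c <- b t h w c, [h, w] b t' pattern would make explicit and readable.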
this would be huge! you have no idea the needless complexity i have written up in the past https://github.com/lucidrains/point-transformer-pytorch/blob/main/point_transformer_pytorch/point_transformer_pytorch.py#L13 lol
@arogozhnikov what would it take for you to build this out to your heart's content?
you are the only one in the world who can do this justice, imo
This would be fantastic! I love Julia, but Python is still required in production at most companies, so getting this in einops would be stellar!
i'm convinced that works of art like einops can't be extrinsically motivated into existence, but if Alex wants to put up a Patreon, would be glad to become a patron in the short term, with no obligation on his end. "greatness cannot be planned"
oh. my. god. https://github.com/arogozhnikov/eindex it's happening
Can einops add argmax/argmin to the reduce function?
Like this:
einops.reduce(tensor, "i j -> i", 'argmax')
Or even directly call custom functions:
einops.reduce(tensor, "i j -> i", torch.argmax)
@Bit0r https://github.com/arogozhnikov/eindex/blob/main/tutorial/tutorial.ipynb
I know this library, but theoretically argmax/argmin should also be a reduce operation. A unified reduce API would be more convenient.
argmin/argmax are not reductions
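One way to see the distinction: a reduction maps values to a value of the same kind, while argmax returns positions, and the meaning of those positions depends on how axes were grouped. A small NumPy illustration (my own sketch, not einops code):

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# max over 'i j -> i' is a genuine reduction: values in, values out.
assert np.array_equal(a.max(axis=1), np.array([3, 7, 11]))

# argmax over the combined (i, j) axes yields a *flat* index into the
# grouped axes, which must then be unraveled back into coordinates --
# extra output structure that a value-to-value reduce cannot express.
flat_idx = a.reshape(-1).argmax()
i, j = np.unravel_index(flat_idx, a.shape)
```

This is why eindex treats index computation as its own operation with an explicit output pattern like '[h, w] b t', rather than shoehorning it into reduce.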
As suggested on Twitter by @colah: https://twitter.com/ch402/status/1539774943214178304
There are a few more syntax ideas in the Twitter thread. I'm not entirely sure what this could look like, but I agree that array indexing syntax is one of the hardest parts of NumPy that isn't already served by Einops.