getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Tensordot problem #204

Open gahdritz opened 2 years ago

gahdritz commented 2 years ago

I'm trying to convert the following PyTorch code into performant KeOps code that never materializes the O(N^2 C^2) intermediate activation in the middle:

# [*, N, N, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b)

# [*, N, N, C * C]
outer = outer.reshape(outer.shape[:-2] + (-1,))

# [*, N, N, C_z]
outer = self.linear_out(outer)

At the start, a and b are both of shape [*, N, M, C]. linear_out is a linear layer of shape [C * C, C_z]. C_z << C * C.

I think I could do both the einsum and the reduction separately in KeOps. The einsum can be simulated by reshaping a and b to [*, N, 1, C, 1, M] and [*, 1, N, 1, C, M], respectively, before turning them into LazyTensors and running (a * b).sum(-1). The reduction could be accomplished in a similar way. Since the [*, N, N, C, C] LazyTensor output of the first operation can't be reshaped, however, I don't know how to combine the two. Any help would be appreciated.
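
For reference, here is a quick plain-PyTorch check (illustrative sizes, batch dimensions dropped, no KeOps involved yet) that the reshape pattern described above does reproduce the einsum:

import torch

# Purely illustrative sizes; the real N, M, C come from the model.
N, M, C = 6, 5, 4
a = torch.randn(N, M, C)
b = torch.randn(N, M, C)

# Reference einsum from the code above, with the batch dims dropped.
ref = torch.einsum("bac,dae->bdce", a, b)        # [N, N, C, C]

# Same contraction via the reshapes described above: a broadcasted
# product followed by a sum over the shared M axis.
a_r = a.permute(0, 2, 1).reshape(N, 1, C, 1, M)  # [N, 1, C, 1, M]
b_r = b.permute(0, 2, 1).reshape(1, N, 1, C, M)  # [1, N, 1, C, M]
out = (a_r * b_r).sum(-1)                        # [N, N, C, C]

print(torch.allclose(ref, out, atol=1e-5))       # True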

jeanfeydy commented 2 years ago

Hi @gahdritz ,

Thanks for your interest in the library, and apologies for taking so long to answer - December has been pretty disrupted by Covid here in Paris.

To answer your question: could you tell us more about the dimensions N, M, C, and C_z? Depending on their sizes, the optimal way of implementing your computation may be very different. (An important point is that if M is small, you should probably pre-compute the matrix-vector product between "C" (the linear_out weights) and "b" to obtain a "C x C_z" matrix that can then be multiplied with "a" efficiently. Generally speaking, expanding "dot(a, C @ b)" into "dot(C, a.outer(b))" as in the code above may not be great for performance. Also, if M, C and C_z are all larger than 64-128, it is likely that KeOps won't be very helpful, unfortunately.)
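
For concreteness, here is a rough PyTorch sketch of this reordering (illustrative sizes, bias of linear_out omitted, W standing in for its weight matrix). Contracting "b" with the weight first keeps the largest intermediate at [N, M, C, C_z] instead of [N, N, C, C]:

import torch

# Illustrative sizes only; the real values determine which ordering wins.
N, M, C, C_z = 16, 4, 8, 3
a = torch.randn(N, M, C)
b = torch.randn(N, M, C)
W = torch.randn(C_z, C * C)                    # stands in for linear_out.weight
W3 = W.view(C_z, C, C).permute(1, 2, 0)        # [C, C, C_z]

# Original ordering: materialize the [N, N, C, C] outer product, then project.
outer = torch.einsum("bac,dae->bdce", a, b)    # [N, N, C, C]
ref = outer.reshape(N, N, -1) @ W.t()          # [N, N, C_z]

# Reordered: contract b with the weight first, so the largest intermediate
# is [N, M, C, C_z] rather than [N, N, C, C].
Wb = torch.einsum("dae,cez->dacz", b, W3)      # [N, M, C, C_z]
out = torch.einsum("bac,dacz->bdz", a, Wb)     # [N, N, C_z]

print(torch.allclose(ref, out, atol=1e-4))     # True

Whether this reordering actually wins depends on the real N, M, C and C_z, hence the question above.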

Finally, I understand that you are mostly interested in some kind of sum reduction over the final tensor "outer", i.e. in a [*, N, C_z] array, to implement an attention layer? Could you tell us more about it?
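
If a plain sum over the second N axis is indeed what is needed (an assumption, not something confirmed in this thread), a rough pykeops sketch building on the reordering above could look as follows; sizes are illustrative, and the Python loop over output channels is only tolerable because C_z is small:

import torch
from pykeops.torch import LazyTensor

# Small, purely illustrative sizes.
N, M, C, C_z = 100, 4, 8, 3
a = torch.randn(N, M, C)
b = torch.randn(N, M, C)
W3 = torch.randn(C, C, C_z)                    # stands in for linear_out.weight viewed as [C, C, C_z]

# Contract b with the weight first, as in the sketch above.
Wb = torch.einsum("jme,cez->jmcz", b, W3)      # [N, M, C, C_z]

a_i = LazyTensor(a.reshape(N, M * C)[:, None, :])              # (N, 1, M*C), "i" variable
out = torch.empty(N, C_z)
for z in range(C_z):                                           # C_z << C * C, so this loop is cheap
    Wb_j = LazyTensor(Wb[..., z].reshape(N, M * C)[None])      # (1, N, M*C), "j" variable
    # Symbolic (N, N) matrix of dot products, reduced over j without ever being stored.
    out[:, z] = (a_i * Wb_j).sum(-1).sum(dim=1).squeeze(-1)    # (N,)

# Dense check of the same reduction.
ref = torch.einsum("imc,jme,cez->iz", a, b, W3)
print(torch.allclose(ref, out, rtol=1e-3, atol=1e-3))          # True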

Best regards, Jean

sebastienwood commented 2 years ago

Hi! I have a similar problem at hand. Were you able to resolve it on your side, @gahdritz? @jeanfeydy, I can give the details of my problem; would you prefer it here or in a separate ticket to keep things clean?