tatp22 / linformer-pytorch

My take on a practical implementation of Linformer for Pytorch.
https://arxiv.org/pdf/2006.04768.pdf
MIT License

Question: Is Linformer permutation equivariant (set-operation)? #26

Open nmakes opened 1 year ago

nmakes commented 1 year ago

Hi. Thanks for the wonderful implementation!

I was wondering whether Linformer can be used with an unordered set of tensors, or whether it is only meant for sequence data. Specifically, is Linformer permutation equivariant?

I'm looking to apply linear attention to points in 3D space (e.g. a point cloud with ~100k points). Would Linformer attention be meaningful there?

(I'm concerned about the n -> k projection, which, if I understand correctly, assumes the n points are in some fixed order.)

Thanks!

tatp22 commented 1 year ago

Hey @nmakes! When I originally used the Linformer, it was for a similar task (unstructured data). I wrote a report on it, and what I found was that it was about as effective as other sparse attention models. So I think it should work :slightly_smiling_face:

nmakes commented 1 year ago

Hey @tatp22, thanks for the answer.

Interesting! Could you please give a little more intuition on why you think it worked (what the task was, and whether there are any caveats)? :)

I'm actually seeing a clear regression in my task. Your insights would be super useful!

Thanks!

tatp22 commented 1 year ago

Hey @nmakes!

My task was similar to yours. I had a 3D point cloud, with the data representing the cloud at different points in time, and a vector of about 7 values for each 3D point. The points were sometimes batched together along a dimension (for example, I would group 5 points in time along the x dimension) so that temporal information was integrated into my prediction.

My task was to predict the future properties of this point cloud. The times were batched, for example, into one-hour intervals, and I had to predict 72 hours into the future and evaluate the results.

Why did I think it worked? Because I think the model learned the most important relations between points on its own. That's why I don't think it matters so much what kind of data is fed into the model; more often than not, the model will find the regression on its own.

Let me know if you have any more questions!

nmakes commented 1 year ago

Hey @tatp22,

Thank you so much for the details! :)

Q1: Just to clarify, did you apply attention for each point independently over its own 5 previous timesteps? Or was the attention applied over the other points as well (e.g. Nx5 queries)?

It does make sense to apply attention over past timesteps for each point independently in your example, where the task is to predict future timesteps for that particular point. But, referring to my earlier question, I'm trying to understand why Linformer attention would work on unordered points.

Here's a small experiment I did. TL;DR: changing the order of the points changes the outputs of the transformer:

from linformer_pytorch import Linformer
import torch

model = Linformer(
        input_size=5, # Dimension 1 of the input
        channels=3, # Dimension 2 of the input
        dim_d=3, # The inner dimension of the attention heads
        dim_k=3, # The second dimension of the P_bar matrix from the paper
        dim_ff=3, # Dimension in the feed forward network
        dropout_ff=0.15, # Dropout for feed forward network
        nhead=6, # Number of attention heads
        depth=2, # How many times to run the model
        dropout=0.1, # How much dropout to apply to P_bar after softmax
        activation="gelu", # What activation to use. Currently, only gelu and relu supported, and only on ff network.
        checkpoint_level="C2", # What checkpoint level to use. For more information, see below.
        ).cuda().eval()

Suppose we have a point cloud with 5 3D points:

>>> x = torch.randn(1, 5, 3).cuda()
>>> x
tensor([[[ 2.5748,  0.9807,  2.6821],
         [-0.4248,  0.6271, -0.9472],
         [-0.4336, -1.2144,  0.9712],
         [ 1.3365,  0.0667,  0.0718],
         [ 0.4151, -0.6590,  0.2932]]], device='cuda:0')
>>> y = model(x)
>>> y
tensor([[[ 0.7686, -1.4124,  0.6437],
         [-0.1116,  1.2767, -1.1651],
         [ 0.0729, -1.2596,  1.1867],
         [ 1.4137, -0.6734, -0.7402],
         [ 0.8355, -1.4059,  0.5704]]], device='cuda:0')

Now, we swap the 0th and 4th index points in x:

>>> x2 = x.clone()
>>> x2[:, 0] = x[:, 4]
>>> x2[:, 4] = x[:, 0]
>>> print(x2)
tensor([[[ 0.4151, -0.6590,  0.2932],
         [-0.4248,  0.6271, -0.9472],
         [-0.4336, -1.2144,  0.9712],
         [ 1.3365,  0.0667,  0.0718],
         [ 2.5748,  0.9807,  2.6821]]], device='cuda:0')

Note that we only swapped the first and the last tensors; the point cloud remains the same. However, passing it through the transformer changes the features, even for the points that were not swapped (idx=1 to idx=3).

>>> y2 = model(x2)
>>> y2
tensor([[[ 0.7401, -1.4137,  0.6735],
         [-0.1346,  1.2865, -1.1519],
         [-0.0927, -1.1758,  1.2685],
         [ 1.4140, -0.6844, -0.7296],
         [ 0.2472, -1.3295,  1.0823]]], device='cuda:0')

This is why I'm finding it a little hard to understand how to make Linformer work for unstructured data.

Q2: Did you mean that even with this behavior, Linformer is still expected to improve the representations for the task? If so, how do we handle inference, where the ordering can be random (different results for the same scene depending on how the input is fed in each time)?

PS: With the same code, setting full_attention=True during model init works as expected: the transformed points are the same for the corresponding points in x and x2. The challenge is to get this permutation-equivariance property with linear attention.
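
For reference, here's a minimal sketch of that full-attention check (same constructor arguments as above, just with full_attention=True added):

from linformer_pytorch import Linformer
import torch

model_full = Linformer(
        input_size=5,
        channels=3,
        dim_d=3,
        dim_k=3,
        dim_ff=3,
        dropout_ff=0.15,
        nhead=6,
        depth=2,
        dropout=0.1,
        activation="gelu",
        checkpoint_level="C2",
        full_attention=True, # O(n^2) attention, included in the repo for comparison
        ).cuda().eval()

x = torch.randn(1, 5, 3).cuda()
perm = torch.tensor([4, 1, 2, 3, 0], device=x.device) # swap the 0th and 4th points
y = model_full(x)
y_perm = model_full(x[:, perm])

# Equivariance check: permuting the inputs should only permute the outputs.
print(torch.allclose(y[:, perm], y_perm, atol=1e-5)) # True with full attention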

tatp22 commented 1 year ago

Ah, ok, I understand your points now. To give you an answer, I did the second, Nx5 version, so there were a lot of points! As you probably guessed, normal attention would have been too big, so I resorted to sparse attention, which helped me there.

Q1: See #15 for more information about this. TL;DR: yes, the internal downsampling does scatter the data around, so this property is not guaranteed. I am not sure whether it would work for your task, but have you tried encoding positional data into the model? Perhaps with my other repository? https://github.com/tatp22/multidim-positional-encoding :slightly_smiling_face:
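
For what it's worth, here's a rough sketch of wiring that in, assuming the positional-encodings package from that repo (the import path below matches recent versions of the package and may differ in older ones); Summer simply adds a fixed sinusoidal encoding along the sequence dimension:

import torch
from linformer_pytorch import Linformer
from positional_encodings.torch_encodings import PositionalEncoding1D, Summer

# Adds a fixed sinusoidal encoding along dimension 1 of a (batch, n, channels) tensor.
add_pos = Summer(PositionalEncoding1D(3))

model = Linformer(
        input_size=5,
        channels=3,
        dim_d=3,
        dim_k=3,
        dim_ff=3,
        dropout_ff=0.15,
        nhead=6,
        depth=2,
        dropout=0.1,
        activation="gelu",
        checkpoint_level="C2",
        ).cuda().eval()

x = torch.randn(1, 5, 3).cuda()
y = model(add_pos(x)) # encode index positions, then run the Linformer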

But I think achieving this equivariance property is hard, if not impossible, with linear attention: whatever downsampling method you choose, some information is necessarily lost. What's nice about full attention is that every point's information is compared with every other point's, which is why equivariance is possible there. Unless you keep that guarantee with linear attention, which this repo doesn't due to the downsampling, it is gone.
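
As a toy illustration of that point (plain PyTorch, not this repo's exact code): the learned k x n projection mixes tokens with position-specific weights, so reordering the points changes the compressed keys:

import torch

n, k, d = 5, 3, 4
E = torch.randn(k, n) # stand-in for the learned k x n downsampling matrix
K = torch.randn(n, d) # keys, one row per point
perm = torch.tensor([4, 1, 2, 3, 0]) # swap the 0th and 4th points, as in the experiment above

K_proj = E @ K # compressed keys, shape (k, d)
K_proj_perm = E @ K[perm] # same point set, different order

# Row i of E always weights whatever point sits at position i, so reordering
# the points changes the compressed keys. Full attention (softmax(Q K^T) V)
# has no such position-specific mixing, which is why it stays equivariant.
print(torch.allclose(K_proj, K_proj_perm)) # False (except by numerical coincidence)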

(ps: try setting k=n. You might get equivariance then, depending on the sampling method!)

Q2: Yes, it should! I think the power here comes from the fact that there are so many parameters in the model that the Linformer learns the relationships anyway, with the Q and V matrices holding redundant information. Even if you feed the points in a different order while training, I think the model should still be powerful enough to pick up the relationships, due to the sheer number of params.

I hope this helps!