lucidrains / point-transformer-pytorch

Implementation of the Point Transformer layer, in PyTorch
MIT License

Cost too much memory #5

Open · JLU-Neal opened this issue 3 years ago

JLU-Neal commented 3 years ago

I'm not sure whether I used the point transformer correctly: I implemented just one block for training, and the shapes of (x, pos) on each GPU are both [16, 2048, 3]. Later I was informed that my GPU ran out of memory (11.77 GB total capacity).
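For context, a rough back-of-envelope estimate (assuming a hypothetical feature dimension of 128, which isn't stated above) shows why full pairwise vector attention at this size exceeds 11.77 GB:

```python
# Rough memory estimate for full (non-kNN) vector attention, fp32.
# b, n come from the reported input shapes; d = 128 is an assumption.
b, n, d = 16, 2048, 128
bytes_per_float = 4

# Vector attention materializes per-pair, per-channel tensors of shape
# (b, n, n, d), e.g. query-key differences plus positional embeddings.
pairwise = b * n * n * d * bytes_per_float
print(f"one (b, n, n, d) tensor: {pairwise / 1e9:.1f} GB")  # ~34.4 GB

# Even the pairwise relative positions alone, shape (b, n, n, 3):
rel = b * n * n * 3 * bytes_per_float
print(f"rel_pos (b, n, n, 3): {rel / 1e9:.2f} GB")  # ~0.81 GB
```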

lucidrains commented 3 years ago

Oh yup, this type of vector attention is quite expensive. In the paper, they did kNN on each point and only attended to the local neighbors.

lucidrains commented 3 years ago

You can use this great library to do the clustering https://github.com/rusty1s/pytorch_cluster
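For the clustering step, a minimal sketch with torch_cluster's knn_graph (k = 16 is an arbitrary choice here):

```python
import torch
from torch_cluster import knn_graph

pos = torch.randn(2048, 3)                   # one point cloud's coordinates
batch = torch.zeros(2048, dtype=torch.long)  # all points in a single example

# edge_index has shape (2, ~n * k); with the default flow='source_to_target',
# row 0 holds neighbor (source) indices and row 1 the query (target) indices.
edge_index = knn_graph(pos, k=16, batch=batch, loop=False)
```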

lucidrains commented 3 years ago

@JLU-Neal Hi, I decided to add a feature for attending only to the k nearest neighbors, which can be set with the num_neighbors keyword argument. Let me know if it works for you! https://github.com/lucidrains/point-transformer-pytorch/commit/2ec1322cbecced477826652c567ed8bc2d31c952
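For reference, usage then looks roughly like the repo README, with the new keyword argument:

```python
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16   # each point only attends to its 16 nearest neighbors
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

out = attn(feats, pos, mask = mask)  # (1, 16, 128)
```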

JLU-Neal commented 3 years ago

> @JLU-Neal Hi, I decided to add a feature for attending only to the k nearest neighbors, which can be set with the num_neighbors keyword argument. Let me know if it works for you! 2ec1322

Wow, surprised by your efficient work! Wishing you a happy Chinese New Year!

zimonitrome commented 3 years ago

kNN is only specified for the "transition down" layer. They don't seem to mention it for the general point transformer layer. So is this just an added bonus or am I missing something from the original paper?

JLU-Neal commented 3 years ago

> kNN is only specified for the "transition down" layer. They don't seem to mention it for the general point transformer layer. So is this just an added bonus or am I missing something from the original paper?

In Section 3.2, Point Transformer Layer, the authors mention that "the subset X(i) ⊆ X is a set of points in a local neighborhood (specifically, k nearest neighbors) of xi."

ouenal commented 3 years ago

> kNN is only specified for the "transition down" layer. They don't seem to mention it for the general point transformer layer. So is this just an added bonus or am I missing something from the original paper?
>
> In Section 3.2, Point Transformer Layer, the authors mention that "the subset X(i) ⊆ X is a set of points in a local neighborhood (specifically, k nearest neighbors) of xi."

On that note, do you think they mean k-nearest neighbours in the point space? They always refer to the coordinates as p, not X, throughout the paper. I've always read it as kNN in the feature space, which, although it might be less stable, could increase the receptive field quite a bit.
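The two readings differ only in which tensor the pairwise distances are computed over; a minimal sketch (all shapes hypothetical):

```python
import torch

b, n, k = 1, 2048, 16
pos = torch.randn(b, n, 3)   # point coordinates (p in the paper)
x = torch.randn(b, n, 128)   # per-point features

# kNN in point space: neighborhoods stay geometrically local.
_, knn_pos = torch.cdist(pos, pos).topk(k, dim=-1, largest=False)

# kNN in feature space: neighborhoods can span the whole cloud, enlarging
# the receptive field, but they change from layer to layer as features move.
_, knn_feat = torch.cdist(x, x).topk(k, dim=-1, largest=False)
```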

L-Reichardt commented 2 years ago

@lucidrains I have the same issue and cannot get a single layer to run within 12 GB of GPU memory. Maybe my understanding of the layer is incorrect, but I have made the following observation:

I have a point cloud of shape (160000, 3), which I pass into the layer as both the features and the positions, with nearest neighbors k=8. However, I noticed (using smaller input data) that in the forward function the relative-position tensor takes up memory quadratic in the number of points and does not change with k. Is the implementation here correct?

rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
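For what it's worth, that line does materialize every pairwise offset before any top-k selection, so its footprint is independent of k; a minimal shape check (with a hypothetical 2048-point cloud):

```python
import torch

# With pos of shape (b, n, 3), broadcasting yields rel_pos of shape
# (b, n, n, 3) -- O(n^2) regardless of num_neighbors, since the top-k
# neighbor selection in the layer only happens afterwards.
pos = torch.randn(1, 2048, 3)
rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
print(rel_pos.shape)  # torch.Size([1, 2048, 2048, 3])
```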

kidpaul94 commented 1 year ago

> @lucidrains I have the same issue and cannot get a single layer to run within 12 GB of GPU memory. Maybe my understanding of the layer is incorrect, but I have made the following observation:
>
> I have a point cloud of shape (160000, 3), which I pass into the layer as both the features and the positions, with nearest neighbors k=8. However, I noticed (using smaller input data) that in the forward function the relative-position tensor takes up memory quadratic in the number of points and does not change with k. Is the implementation here correct?
>
> rel_pos = pos[:, :, None, :] - pos[:, None, :, :]

I believe this code is designed to support both global and local attention by switching kNN on and off. The line linked below checks whether kNN is enabled and then selects the relevant qk_rel, v, and positional-embedding entries based on the kNN indices.

https://github.com/lucidrains/point-transformer-pytorch/blob/f9d4e56a26ceee70deb60da230fef40c656396e6/point_transformer_pytorch/point_transformer_pytorch.py#L78
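In simplified form (not the repo's exact helper code), the gating amounts to building the full pairwise tensors and then gathering only each query's k nearest neighbors, which also explains the observation above that peak memory does not shrink with k; all shapes below are hypothetical:

```python
import torch

b, n, d, k = 1, 512, 64, 16
qk_rel = torch.randn(b, n, n, d)   # pairwise query-key differences
v = torch.randn(b, n, n, d)        # values, expanded to one copy per pair
rel_pos = torch.randn(b, n, n, 3)  # pairwise relative positions

# Pick the k nearest neighbors of each query point by Euclidean distance.
dist = rel_pos.norm(dim=-1)                       # (b, n, n)
_, indices = dist.topk(k, dim=-1, largest=False)  # (b, n, k)

# Gather the selected neighbors along the "key" dimension (dim=2).
idx = indices[..., None].expand(-1, -1, -1, d)    # (b, n, k, d)
qk_rel = qk_rel.gather(2, idx)                    # (b, n, k, d)
v = v.gather(2, idx)                              # (b, n, k, d)
```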