jadie1 / Point2SSM

Point2SSM: Learning Morphological Variations of Anatomies from Point Cloud
https://arxiv.org/abs/2305.14486

High time complexity when the number of output points increases #5

Closed · Starry-lei closed this issue 4 months ago

Starry-lei commented 4 months ago

Hi Jadie,

Thank you for your great work. I've noticed that when I increase the number of output points to 2048 and use a larger dataset with about 800 samples, training becomes quite slow. Do you have any suggestions for producing more correspondences (2048 or more points) on a larger dataset?

I know the time complexity of cross-attention is O(n^2); do you think sparse attention would help?

Best, Lei

jadie1 commented 4 months ago

Hi Lei,

Yes, the attention module is definitely the bottleneck in terms of memory and time complexity. You can try reducing the size of the input feature vector L or the dimensions of the intermediate SFA blocks in the attention module. Sparse attention may also help. I didn't perform a ton of ablations with different attention module architectures, but performance did seem fairly robust to changes in dimension size.
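
To make the bottleneck concrete: the quadratic cost comes from materializing the full attention score matrix between the output points and the context points. Here's a rough illustration with a vanilla cross-attention block (not the actual Point2SSM SFA code, just a sketch; the dimension names are made up):

```python
import torch
import torch.nn as nn

class NaiveCrossAttention(nn.Module):
    """Standard scaled dot-product cross-attention, written out to show the N x M score matrix."""
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, queries, context):
        # queries: (B, N, dim), context: (B, M, dim)
        q = self.q_proj(queries)
        k = self.k_proj(context)
        v = self.v_proj(context)
        # This (B, N, M) score matrix is the memory/time bottleneck:
        # with N = M = 2048 output points it grows quadratically in the point count.
        scores = torch.einsum('bnd,bmd->bnm', q, k) / (q.shape[-1] ** 0.5)
        attn = scores.softmax(dim=-1)
        return torch.einsum('bnm,bmd->bnd', attn, v)
```

Shrinking `dim` (or the intermediate SFA dimensions) reduces the projection and value costs, but the N x M score matrix itself only shrinks if you use a sparse or memory-efficient attention variant.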

Jadie

Starry-lei commented 4 months ago

Thanks for your prompt reply. I replaced the attention with the memory-efficient attention from xFormers and used mixed-precision training. These two engineering modifications do make training with 2048 points faster, and they also seem to slightly improve the best Chamfer distance metric. If you'd allow me, I could contribute these two modifications to your Point2SSM and Point2SSM++.
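
Roughly, the two changes look like this (a simplified sketch rather than the exact patch; `model`, `loader`, `optimizer`, and `chamfer_loss` are placeholders for the real training setup):

```python
import torch
from torch.cuda.amp import autocast, GradScaler
from xformers.ops import memory_efficient_attention

def efficient_cross_attention(q, k, v):
    # q: (B, N, H, D), k/v: (B, M, H, D) -- xFormers expects heads as a separate dim.
    # The full (N, M) attention matrix is never materialized, so memory stays
    # roughly linear in the number of output points.
    return memory_efficient_attention(q, k, v)

# Mixed-precision training loop skeleton (model / loader / optimizer / chamfer_loss are placeholders).
scaler = GradScaler()
for points, target in loader:
    optimizer.zero_grad()
    with autocast():                 # forward pass in fp16 where safe
        pred = model(points)
        loss = chamfer_loss(pred, target)
    scaler.scale(loss).backward()    # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```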

I really like your SSM work :+1: Best, Lei

jadie1 commented 4 months ago

That's great! Yeah, if you make a pull request I'd be happy to add your contributions.

Starry-lei commented 4 months ago

Delighted to contribute to your work!