octree-nn / octformer

OctFormer: Octree-based Transformers for 3D Point Clouds

OctFormer performs worse with deeper Octree #25

Closed · Grifent closed this issue 5 months ago

Grifent commented 5 months ago

Thank you for this wonderful contribution.

I'm currently attempting to train OctFormer for the downstream task of LiDAR place recognition on a dataset not used in your repo (the Oxford RobotCar dataset). I have kept the network layout of OctFormer essentially the same but added a feature aggregator at the end of the network, and have tried both aggregating the output features of the final stage and using an FPN structure like the one you used for semantic segmentation. I have managed to get OctFormer performing reasonably well at this task with both configurations, but I have found that performance degrades rapidly when using an input Octree depth of 7 or higher. This doesn't make sense to me, as a deeper Octree should correspond to a higher-resolution input.
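(For reference, a minimal version of such an aggregator could look like the sketch below; the helper name is hypothetical, and it assumes a per-octant batch index such as the one ocnn-pytorch exposes via octree.batch_id(depth, nempty=True).)

```python
import torch

def global_avg_pool(feat: torch.Tensor, batch_id: torch.Tensor, batch_size: int):
    # feat: (N, C) features of the non-empty octants at the final stage.
    # batch_id: (N,) index of the point cloud each octant belongs to,
    #           e.g. octree.batch_id(depth, nempty=True) in ocnn-pytorch.
    out = feat.new_zeros(batch_size, feat.shape[1])
    out.index_add_(0, batch_id, feat)                      # per-cloud feature sum
    counts = torch.bincount(batch_id, minlength=batch_size)
    return out / counts.clamp(min=1).unsqueeze(1).to(feat.dtype)
```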

I was wondering if you might have any idea what could cause this performance drop? The dataset in question is somewhat unusual, as I am using a pre-processed benchmark that has already been downsampled to 4096 points and normalised to the [-1, 1] range (I know that using only 4096 points defeats the purpose of OctFormer's efficiency, but I must start with this benchmark as it is the most common one for this task). The original dataset has ~22,000 outdoor point clouds in the training set, with a typical extent of [-30 m, 30 m] per axis. The downsampling done in pre-processing uses an average voxel size of 0.3 m, so my assumption is that an Octree depth of 8 should be sufficient to represent this resolution (60 m / 0.3 m = 200 voxels per dimension, and 2^8 = 256 octants per dimension). This is on par with CNN-based methods, which typically quantize these point clouds to 200 voxels per dimension.
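For clarity, the back-of-the-envelope calculation behind those numbers is just:

```python
import math

# Sanity check: the 60 m extent is normalised into the octree's unit cube,
# and the benchmark's downsampling uses ~0.3 m voxels.
extent = 60.0                        # metres per axis before normalisation
voxel = 0.3                          # average voxel size of the pre-processing
cells = extent / voxel               # 200 voxels per dimension
depth = math.ceil(math.log2(cells))  # smallest depth whose grid covers 200 cells
print(cells, depth, 2 ** depth)      # 200.0 8 256 -> depth 8 gives 256 octants/dim
```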

I believe I am building the Octrees correctly: I first perform data augmentation (random rotation, scaling, cropping, etc.), then clip the point clouds to [-1, 1], then construct the Octree and call octree.build_octree(), then collate batches with ocnn.octree.merge_octrees(), and finally call octree.construct_all_neigh(). The dataset only contains point positions, so for input features I am using the 'P' option (global position) in ocnn.modules.InputFeature().
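In code, the data path looks roughly like the sketch below (treat the exact constructor signatures, e.g. the depth/full_depth arguments and the nempty flag, as assumptions about the ocnn-pytorch API rather than verified usage):

```python
import torch
import ocnn

def build_single_octree(xyz: torch.Tensor, depth: int = 8, full_depth: int = 2):
    # xyz: (N, 3) positions after augmentation, already clipped to [-1, 1].
    points = ocnn.octree.Points(xyz)
    octree = ocnn.octree.Octree(depth, full_depth)
    octree.build_octree(points)
    return octree

def collate_octrees(octrees):
    # Merge the per-sample octrees into one batched octree, then precompute
    # the neighbourhood tables needed by the network.
    octree = ocnn.octree.merge_octrees(octrees)
    octree.construct_all_neigh()
    return octree

# Input features: global position only ('P'), since the dataset has no normals.
input_feature = ocnn.modules.InputFeature(feature='P', nempty=True)
```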

Is there any common reason why deeper Octrees would perform worse? I could understand that very deep Octrees might degrade performance if going deeper no longer increased the effective resolution; however, I have found that on this dataset the number of non-empty octants typically increases until depth 8/9 and then plateaus, hardly growing with deeper Octrees, which is as expected from my earlier calculations.
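(The plateau check itself is trivial; I count non-empty octants per level roughly as below, assuming the octree stores per-depth counts in nnum_nempty as ocnn-pytorch does.)

```python
def nonempty_per_depth(octree):
    # octree.nnum_nempty holds the number of non-empty octants at each level
    # in ocnn-pytorch; adjust the attribute name if your version differs.
    return {d: int(octree.nnum_nempty[d]) for d in range(octree.depth + 1)}
```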

For reference, here are the training loss curves and recall@1 on the test set with different Octree depths.

Thanks.

[Figures: training loss curves and recall@1 on the test set for different Octree depths]

wang-ps commented 5 months ago

I was wondering if you may have any idea what could cause this performance drop?

If the octree depth increases, the number of stages in the network should also be increased accordingly. For example, suppose a neural network takes an image of resolution 256 as input and produces a feature map of resolution 16. If the image resolution increases to 1024, the same network will produce a feature map of resolution 64, and the performance will probably change. For octrees, it is the same.
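As a quick illustration of that arithmetic (plain Python, not tied to any particular network):

```python
def coarsest_resolution(input_res: int, num_stages: int, stride: int = 2):
    # Each stage downsamples by `stride`, so the overall factor 2**num_stages
    # is fixed by the architecture, not by the input resolution.
    return input_res // stride ** num_stages

print(coarsest_resolution(256, 4))   # 16
print(coarsest_resolution(1024, 4))  # 64 -> two extra stages are needed ...
print(coarsest_resolution(1024, 6))  # 16 ... to get back to the same resolution
```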

Is there any common reason why deeper Octrees would be performing worse?

I think that if the number of stages increases along with the depth, the performance should not change much and may even improve. On ScanNet segmentation, I have verified that octrees of depth 11 perform better than octrees of depth 10.

Grifent commented 5 months ago

Ah thank you. I hadn't thought the solution would be that simple, but it seems that adding an extra stage or two is helping performance greatly.

I had another related question, though. Is there a methodology behind the number of OctFormer blocks per stage following [2, 2, 18, 2]? This structure seems quite common among hierarchical transformers, but have you tested any other configurations? I'm curious whether there is a reason for giving significantly more parameters to the second-to-last stage, as opposed to using a constant or gradually increasing number of blocks across the network.
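To make the question concrete, a rough per-stage parameter estimate looks like the sketch below; the channel widths and the 12*C^2-per-block factor are assumptions for illustration, not OctFormer's exact configuration:

```python
def stage_params(channels, blocks, per_block_factor=12):
    # Rough estimate: a transformer block has ~4*C^2 attention weights plus
    # ~8*C^2 MLP weights (expansion 4), i.e. ~12*C^2 parameters per block.
    return [per_block_factor * c * c * b for c, b in zip(channels, blocks)]

channels = [96, 192, 384, 384]                # assumed widths, for illustration
print(stage_params(channels, [2, 2, 18, 2]))  # third stage holds most parameters
print(stage_params(channels, [4, 4, 4, 4]))   # a roughly uniform alternative
```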

wang-ps commented 5 months ago

Why does the number of blocks per stage follow [2, 2, 18, 2]?

Yes, as you mentioned, this network structure is quite common among hierarchical transformers, and I followed that practice from the research on vision transformers. I have not tuned these configurations much, since I do not have many GPU resources. According to my preliminary experiments, the performance is fairly robust to these configurations.

Grifent commented 5 months ago

I see, thank you for the insights.