dvlab-research / Stratified-Transformer

Stratified Transformer for 3D Point Cloud Segmentation (CVPR 2022)
MIT License

AssertionError when testing the pre-trained model s3dis_model_best.pth on S3DIS #77

Closed jxl152 closed 1 year ago

jxl152 commented 1 year ago

Firstly, thanks for your work on the Stratified-Transformer.

When testing the pre-trained model s3dis_model_best.pth on S3DIS, the check "assert (relative_position_index >= 0).all()" fails with an AssertionError. Since the pre-trained model was trained on S3DIS, I expected it to run on S3DIS without problems.

The following four steps are what I have done in testing the pre-trained model on S3DIS.

1. Download S3DIS of the version Stanford3dDataset_v1.2_Aligned_Version

[Image: s3dis_google_drive]

2. Preprocess the downloaded dataset to obtain .npy files

3. Modify the config file of config/s3dis/s3dis_stratified_transformer.yaml

I changed data_root, model_path, and save_folder. Apart from those, I did not change any other settings.

DATA:
  ....
  data_root: dataset/stanford_indoor3d
  ...

TEST:
  ...
  model_path: pretrained/s3dis_model_best.pth
  save_folder: output/test/s3dis/stratified
  ...

4. Testing

I ran the command: python test.py --config config/s3dis/s3dis_stratified_transformer.yaml

[Image: test_error]

The run then stops with the AssertionError. I know it comes from the cRPE mechanism.

Since the pre-trained model was trained on S3DIS, testing it on the same dataset should not be a problem.

Could you please point out in which step I am wrong? Many thanks!

jxl152 commented 1 year ago

While debugging, I found that after the log line "test 1/7, 1/67, 20/65, 80000/236790" there is a case in which relative_position_index contains one negative value: relative_position_index[2842959] = tensor([-1,31,22]).

jxl152 commented 1 year ago

More details. As shown in the previous comment, error_point_idx = 2842959. Then:

xyz[index_0][error_point_idx] = [4.3200, 2.8800, 0.9320]
xyz[index_1][error_point_idx] = [4.6400, 2.8850, 1.0260]

According to line 188 of the code, i.e., relative_position = xyz[index_0] - xyz[index_1]: relative_position[error_point_idx] = [-0.3200, -0.0050, -0.0940]

According to line 189, i.e., relative_position = torch.round(relative_position * 100000) / 100000: relative_position[error_point_idx] = [-0.3200, -0.0050, -0.0940]

Then line 190, i.e., relative_position_index = (relative_position + 2 * self.window_size - 0.0001) // self.quant_size, gives relative_position_index[error_point_idx] = [-1, 31, 22], which triggers the AssertionError at line 191.
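The arithmetic in this comment can be replayed in plain Python. This is a minimal sketch, not the repository's code: window_size = 0.16 and quant_size = 0.01 are assumptions inferred from the reported numbers, and rpe_index is a hypothetical helper that applies the quoted formula to a single vector instead of a tensor.

```python
# Assumed values, inferred from the numbers in this thread (not quoted from the config).
WINDOW_SIZE = 0.16
QUANT_SIZE = 0.01


def rpe_index(relative_position, eps=0.0001):
    """Hypothetical helper: apply the quoted quantization formula to one xyz vector."""
    return [int((r + 2 * WINDOW_SIZE - eps) // QUANT_SIZE) for r in relative_position]


# Relative position reported for error_point_idx = 2842959.
index = rpe_index([-0.3200, -0.0050, -0.0940])
print(index)

# The first component is negative, so a ">= 0" check on it fails.
has_negative = any(i < 0 for i in index)
print(has_negative)
```

The first component sits exactly at -2 * window_size, so subtracting 0.0001 pushes it just below zero before the floor division.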

jxl152 commented 1 year ago

I changed the code for computing relative_position_index from

relative_position_index = (relative_position + 2 * self.window_size - 0.0001) // self.quant_size

to

relative_position_index = (relative_position + 2 * self.window_size) // self.quant_size

This time, the AssertionError is raised by the code: assert (relative_position_index <= 2 * self.quant_grid_length - 1).all(), and error_point_idx = 5045428. Then:

xyz[index_0][error_point_idx] = [4.6400, 2.8860, 1.1470]
xyz[index_1][error_point_idx] = [4.3200, 2.8800, 1.1460]

According to the code, i.e., relative_position = xyz[index_0] - xyz[index_1]: relative_position[error_point_idx] = [0.3200, 0.0060, 0.0010]

According to the code, i.e., relative_position = torch.round(relative_position * 100000) / 100000: relative_position[error_point_idx] = [0.3200, 0.0060, 0.0010]

According to the modified code, i.e., relative_position_index = (relative_position + 2 * self.window_size) // self.quant_size: relative_position_index[error_point_idx] = [64, 32, 32], of which the first element exceeds the upper bound 63 computed by 2 * self.quant_grid_length - 1. Thus, it causes the AssertionError.
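The same check can be replayed for the modified formula. Again a sketch under assumptions: window_size = 0.16, quant_size = 0.01, and quant_grid_length = 32 are inferred from the numbers in this thread (the upper bound 63 implies the last one), and rpe_index_no_eps is a hypothetical helper.

```python
# Assumed values, inferred from the numbers in this thread (not quoted from the config).
WINDOW_SIZE = 0.16
QUANT_SIZE = 0.01
QUANT_GRID_LENGTH = 32


def rpe_index_no_eps(relative_position):
    """Hypothetical helper: the modified formula, without the 0.0001 offset."""
    return [int((r + 2 * WINDOW_SIZE) // QUANT_SIZE) for r in relative_position]


upper_bound = 2 * QUANT_GRID_LENGTH - 1

# Relative position reported for error_point_idx = 5045428.
index_no_eps = rpe_index_no_eps([0.3200, 0.0060, 0.0010])
print(index_no_eps, upper_bound)

# The first component lands one bin past the upper bound.
exceeds_upper = any(i > upper_bound for i in index_no_eps)
print(exceeds_upper)
```

Here the first component sits exactly at +2 * window_size, which is the mirror image of the earlier failing point.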

jxl152 commented 1 year ago

In my opinion, the code for computing the relative_position_index, relative_position_index = (relative_position + 2 * self.window_size - 0.0001) // self.quant_size, has a bug.

It guarantees that no value of relative_position_index exceeds the upper bound (63), but it cannot guarantee that none of them falls below the lower bound (0).
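One possible workaround, sketched under the same assumed values (window_size = 0.16, quant_size = 0.01, quant_grid_length = 32), is to keep the 0.0001 offset and clamp the result into [0, 2 * quant_grid_length - 1]. This is only an illustration: the thread does not say how the maintainers fixed the bug, and clamped_rpe_index is a hypothetical helper.

```python
def clamped_rpe_index(relative_position, window_size=0.16, quant_size=0.01,
                      quant_grid_length=32, eps=0.0001):
    """Hypothetical helper: quantize as in the quoted formula, then clamp into range.

    The default parameter values are assumptions inferred from this thread.
    """
    upper = 2 * quant_grid_length - 1
    clamped = []
    for r in relative_position:
        idx = int((r + 2 * window_size - eps) // quant_size)
        # Fold out-of-range bins back onto the nearest valid bin.
        clamped.append(min(max(idx, 0), upper))
    return clamped


# The two failing points from this thread.
low_case = clamped_rpe_index([-0.3200, -0.0050, -0.0940])
high_case = clamped_rpe_index([0.3200, 0.0060, 0.0010])
print(low_case, high_case)
```

Clamping folds a relative position of exactly -2 * window_size into bin 0, which may or may not be the intended behavior. On tensors, torch.clamp would play the role of the min/max pair used here.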

@X-Lai Do you agree with the above tests and analysis? Many thanks!

X-Lai commented 1 year ago

Thank you very much for pointing out this issue! You are correct. There is a bug.