HuguesTHOMAS / KPConv-PyTorch

Kernel Point Convolution implemented in PyTorch
MIT License

Toronto3D process problems #156

Open GeoSur opened 2 years ago

GeoSur commented 2 years ago

Hello Mr. Thomas, thank you for your wonderful work!! I have met a new question: I modified this dataset to the same format as S3DIS, but when I ran the training process, it raised: "ERROR: It seems that the calibration have not reached convergence. Here are some plot to understand why: If you notice unstability, reduce the expected_N value. If convergece is too slow, increase the expected_N value". I carefully compared my modified data with S3DIS; is this because the xyz coordinate values of my input data are too large? (I have reduced the xyz coordinate values as the dataset instructions say, but they are still greater than 100.)

Another question: How to select an area for validation? Simply change the validation_split in s3dis?

HuguesTHOMAS commented 2 years ago

it raised "ERROR: It seems that the calibration have not reached convergence. Here are some plot to understand why: If you notice unstability, reduce the expected_N value. If convergece is too slow, increase the expected_N value"

This error happens because the number of points per batch is too big or too small. There should be plots appearing when this error happens; could you show them here?

is this because the xyz coordinate value of my input data is too large?

No, it should be because of the number of points per batch. The in_radius and first_subsampling_dl parameters control the number of points per sphere. And the batch_num parameter controls the number of spheres per batch.
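As a rough back-of-envelope heuristic (my own illustration, not a formula from the repo), the three parameters combine like this for a mostly planar outdoor scene:

```python
import math

# Illustrative values; the real ones live in the training configuration
# (e.g. in train_S3DIS.py) and depend on your data.
in_radius = 1.5               # radius (m) of each input sphere
first_subsampling_dl = 0.03   # grid size (m) of the first subsampling
batch_num = 6                 # average number of spheres per batch

# After grid subsampling the point spacing is roughly first_subsampling_dl,
# so a mostly planar scene yields on the order of (pi * r^2) / dl^2 points
# per sphere. This is only a heuristic to reason about the batch limit.
pts_per_sphere = math.pi * in_radius ** 2 / first_subsampling_dl ** 2
pts_per_batch = batch_num * pts_per_sphere
print(round(pts_per_sphere), round(pts_per_batch))
```

With these example values each sphere holds around 8000 points, so the batch limit should calibrate to somewhere near 47000 points.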

Another question: How to select an area for validation? Simply change the validation_split in s3dis?

Yes, indeed.

GeoSur commented 2 years ago

This error happens because the number of points per batch is too big or too small. There should be plots appearing when this error happens; could you show them here?

Thanks for your reply! Here are the results:

(two screenshots attached: calibration plots)
GeoSur commented 2 years ago

The first photo shows the dataset, which I have modified to the S3DIS format.

HuguesTHOMAS commented 2 years ago

Ok, from what I see during calibration, the batch limit (which is the maximum number of points allowed in a batch) is very low, around 7000, when it should be between 50000 and 100000. In my opinion, this can be for two reasons:

In case this is about density, you have two ways to solve the issue:

  1. Increase in_radius, which will automatically increase the number of points per batch. This is also a good thing if you have large objects like cars to classify, as the network will have a larger volume to process. I also assume your dataset has varying densities like Semantic3D, so you want your first_subsampling_dl to stay small; the general rule for this parameter is to be small enough to capture details of the shape of the smallest object you are trying to classify. If you are only interested in cars and pedestrians, you can make it larger, 0.12 m for example. But if you are interested in smaller objects like bollards, low vegetation, or road markings, you should probably keep it small, between 0.03 and 0.05.

  2. If you are happy with the size of your spheres and the subsampling grid, but still have a low number of points per sphere, you can simply increase the number of spheres per batch with the parameter batch_num. This is the simplest way to control the number of points per batch and optimize the network's memory consumption on your GPU. Generally speaking, you can go as high as 100000 points per batch with a good GPU (12 GB), even higher if you have a bigger one. But this depends on a lot of things, so just look at what happens in your case.
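The two options above can be expressed as configuration tweaks; here is a sketch with made-up values (attribute names follow the repo's config style, the numbers are illustrative guesses, not recommendations):

```python
class Config:
    # Hypothetical starting values
    in_radius = 1.2
    first_subsampling_dl = 0.04
    batch_num = 6

config = Config()

# Option 1: bigger spheres (more points per sphere), with the grid kept
# fine enough for small objects like bollards or road markings.
config.in_radius = 2.0

# Option 2: same spheres, but more of them per batch.
config.batch_num = 10

print(config.in_radius, config.batch_num)
```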

Best, Hugues

GeoSur commented 2 years ago

Also I assume your dataset has varying densities like Semantic3D for example

As you assumed, this dataset is a mobile LiDAR dataset too, so I will modify the code in the two ways you mentioned.

Increase in_radius and first_subsampling_dl

Actually, I am not very sure about the logic of setting the in_radius parameter: how does it correlate with the batch limit? Can it be adjusted in multiples of the batch limit?

Increase the number of batch_num

And I will also try to increase batch_num on a good GPU, and compare the results of these two modifications.

Thanks a lot for your patient and professional reply!!!

GeoSur commented 2 years ago

After my attempts, I modified in_radius and first_subsampling_dl as the photo shows. The trainer ran successfully, but a few minutes later this error was raised:

(screenshot attached: error traceback)

I assume this error is happening in the validation process, but I do not know why. Do you have any advice?

GeoSur commented 2 years ago

Is this because I modified the .ply files of the Toronto3D dataset directly instead of pre-processing them as .txt files? The original format of this dataset is .ply files, so I just modified them like S3DIS; maybe that caused this error? How to solve it... Another difference is that Toronto3D has an "unlabeled" class.

HuguesTHOMAS commented 2 years ago

Actually, I am not very sure about the logic of setting the in_radius parameter: how does it correlate with the batch limit? Can it be adjusted in multiples of the batch limit?

in_radius is the radius of the input spheres that the network processes. Bigger spheres mean larger portions of the dataset are fed to the network, it also means more points and thus more memory consumption.

batch_num is the number of distinct spheres you want in your batch. It is actually an average batch size, as explained on page 11 of our paper.

The two parameters are not correlated, but together, they will determine the total number of points your batch contains (sum of the number of points of each sphere).
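So the batch total is just the sum of the per-sphere point counts; a tiny illustration with hypothetical sphere sizes:

```python
import numpy as np

# Hypothetical per-sphere point counts for one batch of 4 spheres
# (in the code, batch.lengths holds these counts).
lengths = np.array([21000, 18500, 24300, 19800])

total_points = int(lengths.sum())   # what the GPU actually has to process
print(total_points)
```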

After my attempts, I modified in_radius and first_subsampling_dl as the photo shows, this trainer ran successfully, but a few minutes later, this error raised:

It seems that you have an empty truth vector: The truth values are stored in a 0D array.

In trainer.py, just before line 539, try to print the shapes of truth and preds. If both are [0], it means you have an empty input sphere.

HuguesTHOMAS commented 2 years ago

I tried to print the shape of "truth", and I find there is "0" at the last line. Which factor may cause this issue?

Are there non-zero values for this shape before the 0 happens? If yes, then there is a problem when selecting a particular input sphere; if no, then there is a problem with the whole validation split.

GeoSur commented 2 years ago

In trainer.py, just before line 539, try to print the shapes of truth and preds. If both are [0], it means you have an empty input sphere.

I have tried to print the shape of "truth" and I find "0" in the last line, and it is not always "0". I have tested this error several times and found that it basically occurs within the first three validation passes, but it seems to happen randomly.

What should I do to solve it? Is this because of the problem with the calibration process?

HuguesTHOMAS commented 2 years ago

I have tried to print the shape of "truth" and I find "0" in the last line, and it is not always "0". I have tested this error several times and found that it basically occurs within the first three validation passes, but it seems to happen randomly.

Ok, so this confirms that the error happens when you select a particular area of the validation point cloud. The spheres are picked randomly, which is why it does not always happen at the same time.

What you can do:

If you go back to trainer.py and follow where truth comes from, it is stored here:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/utils/trainer.py#L507-L509

Place an if statement to track the error and print everything you can to understand it:


# Debug check: catch batches whose truth vector is empty
if target.shape[0] < 1:
    print('batch_ind', b_i)
    print('length', length)
    print('probs shape', probs.shape)
    print('point indices', inds)
    print('cloud index', c_i)

you can even print the value of the points themselves to know where in the dataset the sphere was:


if target.shape[0] < 1:

    in_pts = batch.points[0].cpu().numpy()
    in_pts = in_pts[i0:i0 + length]

    print('in_pts shape', in_pts.shape)
    print('in_pts mean', np.mean(in_pts, axis=0))

Depending on what is shown in these messages, you should find the solution to your problem.

GeoSur commented 2 years ago

Thank you for your patient reply!! I even suspected that it was caused by too much difference between this data and S3DIS; I had also considered switching to the KITTI format. I am now going to locate the problem as you said. Hope I can solve it.

SC-shendazt commented 2 years ago

Hello Mr. Thomas, thank you for your wonderful work!! I have met a question: self.all_splits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]. If my validation areas are 3, 5 and 6, how can I set self.validation_split?

HuguesTHOMAS commented 2 years ago

Hi @SC-shendazt,

the code only handles one validation split in S3DIS.py, but it is very easy to modify:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L140-L157

You can define self.validation_splits = [3, 5, 6] and change the tests

self.all_splits[i] == self.validation_split (and the corresponding !=)

to something like

self.all_splits[i] in self.validation_splits (and not in)
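Put together, the modified membership test behaves like this sketch (names as in S3DIS.py, split values from the question):

```python
all_splits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]   # one split index per cloud
validation_splits = [3, 5, 6]                  # multiple validation areas

# Training clouds: those whose split is NOT a validation split
train_clouds = [i for i, s in enumerate(all_splits) if s not in validation_splits]
# Validation clouds: those whose split IS a validation split
val_clouds = [i for i, s in enumerate(all_splits) if s in validation_splits]

print(train_clouds)  # [0, 1, 2, 4, 7, 8, 9]
print(val_clouds)    # [3, 5, 6]
```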

The only other place where validation_split is used is:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/utils/trainer.py#L431-L433

But you can change it with something like

for val_split_i in val_loader.dataset.validation_splits:
    if val_split_i not in val_loader.dataset.all_splits: 
        return 

This is just a safety check.

SC-shendazt commented 2 years ago

Thank you very much. I'll have a try first

GeoSur commented 2 years ago

Hello, Mr.Thomas! I tried to place these codes below the

predictions.append(probs)
targets.append(target)
i0 += length

like the photo shows, but it printed nothing. Then I deleted the "if statement", and it printed these lines: (two screenshots attached). Maybe I am missing something, but I can't find a solution from these messages. By the way, is this because of the differences (like point density) between Toronto3D and S3DIS?

HuguesTHOMAS commented 2 years ago

mmmh sorry there is a mistake here. The line

i0 += length

should be after these lines:

    in_pts = batch.points[0].cpu().numpy()
    in_pts = in_pts[i0:i0 + length]

    print('in_pts shape', in_pts.shape)
    print('in_pts mean', np.mean(in_pts, axis=0))

Otherwise this does not work.

GeoSur commented 2 years ago

Thank you for your reply! I changed it, but it still did not work, so I deleted the "if statement". The print messages are as the photos show: (two screenshots attached). I guess this is because of the in_radius or first_subsampling_dl parameters: the input spheres of the validation process may sometimes be empty. So maybe I could change the parameters to solve this problem? But this is hard because of the memory consumption and the density of the dataset.

SC-shendazt commented 2 years ago

@GeoSur Regarding "in_radius or first_subsampling_dl": perhaps you could try different values for these two parameters, e.g. in_radius = 16, first_subsampling_dl = 0.2~0.X.

GeoSur commented 2 years ago

@SC-shendazt Thank you!! Actually I am trying to do this, but the GPU memory of this computer is not large enough, so I have to try it on a more powerful device later...

HuguesTHOMAS commented 2 years ago

If the error is caused by an empty input sphere, you could probably place a safety check in the function that gets these spheres and creates the batch:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L231

I'll let you investigate that.

GeoSur commented 2 years ago

@HuguesTHOMAS Thanks for your reply!!! I also think I should find my own solution to this problem. Thanks for your time, I really appreciate it!!! Well, another question: could this network work on a dataset which only has XYZ and class-label values?

SC-shendazt commented 2 years ago

(two screenshots attached) Hello Mr. Thomas, I have met a new question; could you give me some advice?

HuguesTHOMAS commented 2 years ago

Well, this seems pretty clear: your labels seem to have float64 dtype instead of an integer one. You should be able to solve this yourself by following the code where the labels are defined.
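A minimal sketch of that kind of fix, assuming the labels come out of the file reader as float64:

```python
import numpy as np

labels = np.array([0.0, 2.0, 1.0, 2.0])   # as read from file: float64
assert labels.dtype == np.float64

# Cast to an integer dtype before the labels reach the loss function
labels = labels.astype(np.int32)
print(labels.dtype)
```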

SC-shendazt commented 2 years ago

Thanks for your reply!

GeoSur commented 2 years ago

@HuguesTHOMAS Hello Thomas! It is me again...

Could this network work on a dataset which only has XYZ and class-label values (without RGB features, compared with S3DIS)?

HuguesTHOMAS commented 2 years ago

Sure, you just have to change the number of features to 1 in the configuration here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/train_S3DIS.py#L146

If you want more control on the input features you want to add, this is done here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L402-L411

GeoSur commented 2 years ago

@HuguesTHOMAS Thank you for your reply and this wonderful work, it is really cool!!! But why should the parameter "in_features_dim" be modified from 5 to 1? In the S3DIS dataset there are x, y, z, r, g, b, counting to 6 columns. Or is the parameter not correlated with the columns of the dataset? As I can see from the code, it creates a KDTree from the ply files, so may it still need the r, g, b columns as features?

HuguesTHOMAS commented 2 years ago

The features do not include x, y, z. For S3DIS they are: 1, r, g, b, h, where 1 is just a column of ones, and h is the height of the points in global coordinates. Please refer to the paper for more info on the features.

Changing to 1 will ignore additional features and just use the column of ones, which is the basic geometric feature.

As I can see from this code it should create KDTree from the ply files, so it may still need an r,g,b rows or features?

The KDTree is only computed on the point coordinates (x, y, z) for local spherical neighborhoods.
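To make the distinction concrete, here is a sketch with made-up arrays of the S3DIS-style feature stack [1, r, g, b, h], kept separate from the (x, y, z) used for neighborhoods:

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.random((100, 3)).astype(np.float32)   # x, y, z: geometry only
colors = rng.random((100, 3)).astype(np.float32)   # r, g, b

ones = np.ones((points.shape[0], 1), dtype=np.float32)  # constant feature
height = points[:, 2:3]                                  # global z as h

features = np.hstack((ones, colors, height))  # in_features_dim = 5
print(features.shape)                          # (100, 5)

# With in_features_dim = 1, only the column of ones is used:
print(features[:, :1].shape)                   # (100, 1)
```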

GeoSur commented 2 years ago

@HuguesTHOMAS Thanks for your reply! Well, S3DIS.py, line 736:

else:
    print('\nPreparing KDTree for cloud {:s}, subsampled at {:.3f}'.format(cloud_name, dl))

    # Read ply file
    data = read_ply(file_path)
    points = np.vstack((data['x'], data['y'], data['z'])).T
    colors = np.vstack((data['red'], data['green'], data['blue'])).T
    labels = data['class']
It still processes the RGB features, and there does not seem to be an if statement.

GeoSur commented 2 years ago

@HuguesTHOMAS Hello Mr. Thomas! It is me! After these days of testing, I found out what might be causing the problem; please see these pictures (three screenshots attached, showing the "train" code, the "metrics" code and the print message respectively). It seems that pred = np.squeeze(pred) is causing the 0D-array error. So I added an if statement in the trainer, as the picture below shows (screenshot attached). Now the code is working properly. I want to know the consequences of my change: is it correct, and what impact will it have on the training process?

GeoSur commented 2 years ago
  1. The other question is about the dataset which only has X, Y, Z and label values. I have modified all the lines about "color", and the code is working properly too; is this correct?

  2. One training process was interrupted by accident, and I restarted it by modifying the 'previous path'; which one should be selected when testing?

GeoSur commented 2 years ago

@SC-shendazt Hello! Have you tried to split the validation datasets? And do you know how to predict on a dataset without labels, e.g. NPM3D?

HuguesTHOMAS commented 2 years ago

Hi @GeoSur,

Thanks for your update; indeed, it seems the squeeze function can be dangerous in case an array only contains 1 element.

I think a better solution, instead of fixing the mistake here, would be to prevent any input that only contains 1 point from ever being processed by the network. I just pushed a little fix here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/5b5641e02daac0043adfe97724de8c771dd4772f/datasets/S3DIS.py#L344-L351

Could you tell me if that works?

The other question is about the dataset which only has X, Y, Z and label values. I have modified all the lines about "color", and the code is working properly too; is this correct?

Yes, that should work. You can basically put empty arrays (or arrays of zeros) in the colors, because they will not be used when self.config.in_features_dim == 1.
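So for an XYZ-only dataset, a zero placeholder is enough; a minimal sketch (array shapes illustrative):

```python
import numpy as np

points = np.random.rand(50, 3).astype(np.float32)  # x, y, z
# Placeholder colors: never read when config.in_features_dim == 1
colors = np.zeros((50, 3), dtype=np.uint8)
print(colors.sum())  # 0
```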

One training process was interrupted by accident, and I restarted it by modifying the 'previous path'; which one should be selected when testing?

I am not 100% sure that the behavior when using 'previous path' is the same as if you had continued the training without interruption. In my case, I would usually restart the training from the beginning as it is not so long.

GeoSur commented 2 years ago

@HuguesTHOMAS I am really happy I could help you to update this code! And I have a new question about the NPM3D dataset, which I processed with the TensorFlow version. The test data of this dataset has no true labels, so it does not seem to work for semantic segmentation tasks in the torch version, because this code can only train/validate data which has truth labels. This is just my personal conclusion; is it right? How can I process the NPM3D dataset like in the TensorFlow version (training and validation on the labeled data, and prediction on the unlabeled data)?

SC-shendazt commented 2 years ago

@GeoSur Sorry, I didn't test NPM3D.

HuguesTHOMAS commented 2 years ago

@GeoSur,

It is possible to process the NPM3D data with some minor modifications. First, you can choose one of the training point clouds as validation data, like I do in the TensorFlow version. And you can create a test set with the unlabeled point clouds, as I do for SemanticKitti: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/5b5641e02daac0043adfe97724de8c771dd4772f/datasets/SemanticKitti.py#L73-L81

Following the rest of the SemanticKitti dataset class, when loading data, if the set is 'test' then you just don't load semantic labels from files, you can just make some dummy ones: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/5b5641e02daac0043adfe97724de8c771dd4772f/datasets/SemanticKitti.py#L285-L292

Just modify this line: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/5b5641e02daac0043adfe97724de8c771dd4772f/datasets/S3DIS.py#L761

Also remember you cannot use the "balanced class strategy" if there are no labels, so you have to use the potential-based sampling (use_potentials=True when creating the dataset).

Then when the network is trained you can use the test_models.py script to test it on the unlabeled data

N.B. Note that you can also train on all the labeled point clouds to have the best training set when submitting your results online, but in that case you do not have validation results to study which parameter values are best. You can usually do this after you have studied everything well and are sure of your best parameters, to see if you can get a little margin of improvement with more training data.
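The dummy-label pattern for an unlabeled test set can be sketched like this (the function name and shapes are illustrative, not the repo's API):

```python
import numpy as np

def load_labels(set_name, n_points):
    """Return dummy zero labels for the unlabeled 'test' set, mirroring
    the SemanticKitti pattern; real labels would be read from file."""
    if set_name == 'test':
        return np.zeros(n_points, dtype=np.int32)
    raise NotImplementedError('load real labels from file here')

labels = load_labels('test', 1000)
print(labels.shape, int(labels.sum()))  # (1000,) 0
```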

HuguesTHOMAS commented 2 years ago

@GeoSur,

If you create a NPM3D.py dataset class and everything works well, I would be happy to review it and add it to the repository.

Best, Hugues

GeoSur commented 2 years ago

@HuguesTHOMAS , Thanks for your tips! I will try to create the NPM3D.py class...

 # Safe check for empty spheres 
 if n < 2: 
     failed_attempts += 1 
     if failed_attempts > 2 * self.config.batch_num: 
         raise ValueError('It seems this dataset only containes empty input spheres') 
     t += [time.time()] 
     t += [time.time()] 
     continue

These lines worked, but they also interrupted the training process: the error was raised and the code stopped. I think it should simply skip those empty spheres and continue the train/validation process.

HuguesTHOMAS commented 2 years ago

I just pushed a correction where the number of failed attempts can now go up to 100 * batch_num before returning an error.

I have to keep it just as a safe way to avoid getting into an infinite while loop.

Now it should be better, but just note that all these input spheres with just one point show that your dataset surely has a lot of isolated points that should probably be filtered out.
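One possible pre-processing step for that (my own sketch, not code from the repo): drop every point with no neighbor within some radius. A KDTree would scale to real clouds; brute force is enough to show the idea:

```python
import numpy as np

rng = np.random.default_rng(0)
dense = rng.random((1000, 3))              # a dense 1 m cube of points
isolated = np.array([[50.0, 50.0, 50.0]])  # one far-away stray point
cloud = np.vstack((dense, isolated))

# Count neighbors within 0.5 m. Each point counts itself, so isolated
# points have a count of exactly 1 and are removed.
d2 = ((cloud[:, None, :] - cloud[None, :, :]) ** 2).sum(-1)
counts = (d2 <= 0.5 ** 2).sum(axis=1)

filtered = cloud[counts > 1]
print(len(cloud), len(filtered))  # 1001 1000
```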

GeoSur commented 2 years ago

but just note that all these input spheres with just one point show that your dataset surely has a lot of isolated points that probably should be filtered out.

Yes, I noticed it; I will try to fix this later.

GeoSur commented 2 years ago

@HuguesTHOMAS Hello! I created a dataset .py file (actually I just modified a few lines), and the train/validation tasks worked well, with a really high mIoU on the validation and test process (screenshot attached). But when I change the set to the 'test' split, test_models.py did not work well: it predicted the '0' class, which should not appear (ignored class [0]). I assume that is because the test process is stopped by

    # Break when reaching number of desired votes
    if last_min > num_votes:
        break

so the test process did not have enough time to predict all of the points, and the potentials are not correct on the test data either (screenshot attached). So I just commented out the if statement. Could you please tell me what happens when the set == 'test', and why so many points are predicted as '0'? Is it really because of the break instruction?

SC-shendazt commented 2 years ago

Hello Dr. Thomas, if I want to stack the 1024-dimensional features of the last layer after each layer of the upsampling, how should I modify the UnaryBlock in the block_decider module? It is hard for me as a newbie to modify it because your code is so powerful. I want to change the upsampling to a dense connection (dense KPConv). Hope to get your guidance. (screenshot attached)

HuguesTHOMAS commented 2 years ago

Hi @GeoSur,

It is probably not the break statement, because it stops when the minimum of the potentials reaches a certain value, meaning every point in the test set has been processed at least N times, where N equals num_votes. You can verify that by looking at the saved potentials.

I think I found the problem actually. It should come from the fact that in case the set is 'test', the code did not insert a column of predictions equal to zero for the ignored labels. I just pushed a correction and it should work well now. Can you confirm?
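The described correction can be sketched with np.insert (the class count and probability values here are made up):

```python
import numpy as np

# Predicted probabilities over the 2 kept classes (label 0 was ignored
# during training, so the network never outputs a column for it).
probs = np.array([[0.7, 0.3],
                  [0.2, 0.8]])
ignored_labels = [0]

# Re-insert a zero-probability column for each ignored label so that
# argmax indices line up with the original label values again.
for l_ind in ignored_labels:
    probs = np.insert(probs, l_ind, 0, axis=1)

print(probs.shape)           # (2, 3)
print(probs.argmax(axis=1))  # original class ids, never the ignored 0
```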

HuguesTHOMAS commented 2 years ago

Hi @SC-shendazt,

Ok, so if I understand correctly, you want to stack the 1024-dimensional features of the last layer after each layer of the upsampling and before the unary blocks, which means your unary block would have as input a concatenation of the upsampled features, the skip connection features, and the 1024-dimensional last-layer features.

Here is what you need to do:

  1. You need to modify the dimension of the input features for the unary blocks. You can see here that I already do a similar thing, adding the dimension of the skip connection features: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/4cebd543d796ce3c324cb9f39486deab97133bfe/models/architectures.py#L273-L275 So it should be pretty easy for you to add the dimension of the last-layer features too.
  2. You need to modify the forward function where the concatenation actually happens. Similarly you can use my code, where I concatenate the skip connection features, and add your own features too: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/4cebd543d796ce3c324cb9f39486deab97133bfe/models/architectures.py#L335-L336

The last thing I did not mention is that the 1024-dimensional features of the last layer are only defined for the points of the last layer. So you will need to choose how to project them onto the points of the other layers. I see two ways to do that:

  1. You use the upsampling blocks to project them layer by layer. You would need to modify the decoder forward function: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/4cebd543d796ce3c324cb9f39486deab97133bfe/models/architectures.py#L334-L337
  2. More simply, you average the last-layer features into a single one, and then concatenate this single feature to every point at each layer.
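Option 2 can be sketched in plain NumPy (shapes illustrative; in the model this would be a torch.cat inside the decoder forward):

```python
import numpy as np

rng = np.random.default_rng(0)
last_feats = rng.standard_normal((64, 1024))    # points of the last layer
layer_feats = rng.standard_normal((5000, 128))  # points of a decoder layer

# Average the last layer into one global 1024-d descriptor ...
global_feat = last_feats.mean(axis=0)                          # (1024,)
# ... then tile it to every point of the layer and concatenate
tiled = np.broadcast_to(global_feat, (layer_feats.shape[0], 1024))
fused = np.concatenate([layer_feats, tiled], axis=1)
print(fused.shape)  # (5000, 1152)
```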
GeoSur commented 2 years ago

Hi @HuguesTHOMAS! Thanks a lot for your reply! The modified version you pushed works well! And I submitted the result online; here are the result and confusion matrix.

(screenshot attached: benchmark result and confusion matrix)

So do you want me to update the NPM3D.py file to expand the repository?

HuguesTHOMAS commented 2 years ago

Sure! Can you send me your dataset file by email, and I will update the repo!

Thank you for your help and involvement in this repo!

GeoSur commented 2 years ago

Hi @HuguesTHOMAS! I have sent an email to your Gmail address!

HuguesTHOMAS commented 2 years ago

Thanks a lot, I just updated the repository.

Could you try the train_NPM3D.py script to see if it works?

GeoSur commented 2 years ago

@HuguesTHOMAS You are really efficient! I have tried train_NPM3D.py and it works well on an RTX 2070 (8 GB GPU memory). Thanks for your wonderful work! It really helps me a lot! And I am really happy to make a small contribution.