HuguesTHOMAS / KPConv-PyTorch

Kernel Point Convolution implemented in PyTorch
MIT License

Custom Dataset regression task #202

Open Pippo809 opened 1 year ago

Pippo809 commented 1 year ago

Good morning, I want to use this repo to perform a regression task on a set of point clouds. My dataset consists of point clouds (PCDs), each paired with a regression score between 0 and 1. I was thinking of adapting the classification task for this purpose; however, I don't fully understand how my dataset has to be organized. I tried to reverse engineer the ModelNet40 class, but I can't quite understand how __getitem__ works and how it can be adapted for a dataset with only one class (but a regression target).

HuguesTHOMAS commented 1 year ago

Hi @Pippo809,

First of all, you will need to change the loss function of the network here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/models/architectures.py#L151-L171

You can adapt this function for a regression task. I am not sure what the best way to do regression with a deep network is, but in any case, I assume the labels tensor would be a float32 tensor with values between 0 and 1 instead of an int32 tensor of class indices.
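
As a rough sketch of such an adapted loss, assuming the network's final layer is changed to output a single value per cloud, a sigmoid followed by mean squared error is one simple option (not necessarily the best one). The name `regression_loss` is hypothetical; in the repository, this logic would take the place of the cross-entropy term inside the loss method linked above, keeping any other terms that method adds.

```python
import torch
import torch.nn.functional as F


def regression_loss(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical regression replacement for the cross-entropy output loss.

    outputs: (B, 1) raw network outputs, assuming the last layer now has 1 output
    labels:  (B,) float targets in [0, 1]
    """
    # Map raw outputs to (0, 1) so they live in the same range as the targets
    preds = torch.sigmoid(outputs.squeeze(-1))
    # Mean squared error; the labels must stay floating point (no int cast)
    return F.mse_loss(preds, labels.float())


# Usage: all-zero raw outputs give a prediction of 0.5 for every cloud
outputs = torch.zeros(4, 1)
labels = torch.tensor([0.5, 0.5, 0.5, 0.5])
loss = regression_loss(outputs, labels)
```

Squashing with a sigmoid keeps predictions inside the target range; an L1 or smooth-L1 loss would be an equally reasonable swap-in if outliers are a concern.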

From there, you can work backwards. The loss function is called here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/utils/trainer.py#L187-L190

so the labels you need to modify are batch.labels, which are created here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/datasets/ModelNet40.py#L681-L704
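
Assuming the batch class turns the NumPy arrays of the input list into tensors with `torch.from_numpy`, the dtype of the labels array carries straight through: a float32 array becomes a float32 tensor, with no extra change needed at this step. A minimal check with made-up values:

```python
import numpy as np
import torch

# Hypothetical stacked labels for a batch of 3 clouds, as they might appear in the input list
labels_np = np.array([0.12, 0.73, 0.99], dtype=np.float32)

# torch.from_numpy shares memory with the array and keeps its dtype
labels_t = torch.from_numpy(labels_np)
print(labels_t.dtype)  # torch.float32

# By contrast, an integer array (the classification case) becomes an integer tensor
class_t = torch.from_numpy(np.array([3, 0, 7], dtype=np.int64))
print(class_t.dtype)  # torch.int64
```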

where the input list is the one returned by the __getitem__ function. If you trace the labels in this list back to their origin, you will find they are defined here:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/datasets/ModelNet40.py#L175-L178
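
For the dataset side, a stripped-down sketch of the idea: store one float score per cloud in input_labels instead of a class index. The class `RegressionCloudsSketch` is hypothetical and only mirrors the attribute names used in ModelNet40; the repository's own __getitem__ does considerably more work (for example building the network inputs), which is not reproduced here.

```python
import numpy as np


class RegressionCloudsSketch:
    """Hypothetical stand-in for the dataset class: one point cloud per item,
    one float score per cloud instead of a class index."""

    def __init__(self, clouds, scores):
        # clouds: list of (N_i, 3) arrays, e.g. loaded from your point cloud files
        self.input_points = [np.asarray(c, dtype=np.float32) for c in clouds]
        # Regression targets kept as float32 (no name-to-class-id mapping)
        self.input_labels = np.asarray(scores, dtype=np.float32)
        assert len(self.input_points) == len(self.input_labels)

    def __len__(self):
        return len(self.input_labels)

    def __getitem__(self, idx):
        # Return the raw points and their float score for one cloud
        return self.input_points[idx], self.input_labels[idx]


# Usage with random stand-in clouds and made-up scores
rng = np.random.default_rng(0)
clouds = [rng.random((100, 3)) for _ in range(3)]
ds = RegressionCloudsSketch(clouds, [0.12, 0.73, 0.99])
points, score = ds[1]
```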

So from there, there are two things you need to do:

  1. Modify self.input_labels so that it contains your regression values instead of classes.
  2. Verify that, all along the way to the loss function, you remove any .astype(np.int32) or similar conversion applied to the labels, so that the float values stay intact.
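
To see why step 2 matters, here is a sketch with made-up scores: casting float labels in [0, 1) to int32 silently turns them all into 0, so the network would train on constant targets without raising any error. The `check_float_labels` helper is hypothetical; it could be called just before the loss to catch such a cast early.

```python
import numpy as np

scores = np.array([0.12, 0.73, 0.99], dtype=np.float32)

# A leftover classification cast truncates every score in [0, 1) to 0
truncated = scores.astype(np.int32)
print(truncated)  # [0 0 0]

# Keeping (or casting to) a float dtype preserves the regression targets
kept = scores.astype(np.float32)
print(kept.dtype)  # float32


def check_float_labels(labels):
    """Hypothetical helper: fail early if the labels lost their float dtype."""
    dtype = np.asarray(labels).dtype
    if not np.issubdtype(dtype, np.floating):
        raise TypeError(f"labels must be floating point, got {dtype}")
    return labels
```
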

Pippo809 commented 1 year ago

Perfect! Thank you very much for the quick response. I'll see how it goes