kentechx / pointnext

Pytorch implementation of PointNeXt
MIT License

Question regarding X and XYZ #2

Open abrahamezzeddine opened 8 months ago

abrahamezzeddine commented 8 months ago

Hello,

What does X correspond to if we input XYZ coordinates? It is not entirely clear what X will achieve.

Thank you.

kentechx commented 8 months ago

X is the feature of points, consisting of XYZ, normals, colors, etc., whereas XYZ coordinates are used for sampling.

abrahamezzeddine commented 8 months ago

OK, thank you so much!

I assume that I have to make sure the data, for example ModelNet40, is uniformly downsampled to 1024 points before I input it to the model; otherwise I will get an error.

Question: I assume the neural network will then also perform its own FPS to downsample the data? So once before I feed it ModelNet40, and once again inside the network? That seems a bit redundant.

How can I send it ModelNet40 and let the model itself downsample it accordingly?

Thank you.

Btw, I am new to coding, so I am of course trying to learn a lot. :) It seems that I had to adapt the code a bit (tensor data ptr had to be used, or something like that) and also create a new file that wraps the functions with PYBIND11_MODULE; otherwise it would complain a lot.

#include <pybind11/pybind11.h>
#include <torch/extension.h>

// Declarations of functions from different source files
extern void ball_query_wrapper_fast(int64_t b, int64_t n, int64_t m, double radius, int64_t nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

extern void group_points_grad_wrapper_fast(int64_t b, int64_t c, int64_t n, int64_t npoints, int64_t nsample, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
extern void group_points_wrapper_fast(int64_t b, int64_t c, int64_t n, int64_t npoints, int64_t nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

extern void three_nn_wrapper_fast(int64_t b, int64_t n, int64_t m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
extern void three_interpolate_wrapper_fast(int64_t b, int64_t c, int64_t m, int64_t n, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
extern void three_interpolate_grad_wrapper_fast(int64_t b, int64_t c, int64_t n, int64_t m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

extern void gather_points_wrapper_fast(int64_t b, int64_t c, int64_t n, int64_t npoints, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
extern void gather_points_grad_wrapper_fast(int64_t b, int64_t c, int64_t n, int64_t npoints, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
extern void furthest_point_sampling_wrapper(int64_t b, int64_t n, int64_t m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ball_query_wrapper_fast", &ball_query_wrapper_fast, "Fast ball query wrapper");
    m.def("group_points_grad_wrapper_fast", &group_points_grad_wrapper_fast, "Fast group points gradient wrapper");
    m.def("group_points_wrapper_fast", &group_points_wrapper_fast, "Fast group points wrapper");

    m.def("three_nn_wrapper_fast", &three_nn_wrapper_fast, "Three nearest neighbors wrapper fast");
    m.def("three_interpolate_wrapper_fast", &three_interpolate_wrapper_fast, "Three interpolate wrapper fast");
    m.def("three_interpolate_grad_wrapper_fast", &three_interpolate_grad_wrapper_fast, "Three interpolate gradient wrapper fast");

    m.def("gather_points_wrapper_fast", &gather_points_wrapper_fast, "Gather points wrapper fast");
    m.def("gather_points_grad_wrapper_fast", &gather_points_grad_wrapper_fast, "Gather points gradient wrapper fast");
    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "Furthest point sampling wrapper");
}

I am using the latest PyTorch and CUDA 12.1 for this.

kentechx commented 8 months ago

Downsampling ModelNet40 data to 1024 points should be part of the data preparation process, a step not covered in this repository. The network will downsample the input data four times in one run, with a default stride of 4. For example, input with 1024 points will be successively downsampled to 256, 64, 16, and 4 points.
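
To make both points concrete, here is a minimal sketch that is not part of this repository. It assumes NumPy is available; `farthest_point_sample` and `stage_sizes` are hypothetical helper names, the random 5000-point cloud only stands in for one raw shape, and the per-stage arithmetic simply assumes the point count is divided by the stride at each stage, which matches the 1024, 256, 64, 16, 4 sequence above.

```python
import numpy as np


def farthest_point_sample(points: np.ndarray, n_samples: int, seed: int = 0) -> np.ndarray:
    """Return indices of n_samples points picked by farthest point sampling.

    points: (N, 3) array of XYZ coordinates, with N >= n_samples.
    """
    n_points = points.shape[0]
    if n_samples > n_points:
        raise ValueError("n_samples must not exceed the number of points")
    rng = np.random.default_rng(seed)
    selected = np.empty(n_samples, dtype=np.int64)
    # Squared distance from every point to its nearest already-selected point.
    min_dist = np.full(n_points, np.inf)
    selected[0] = rng.integers(n_points)  # arbitrary starting point
    for i in range(1, n_samples):
        last = points[selected[i - 1]]
        dist = np.sum((points - last) ** 2, axis=1)
        min_dist = np.minimum(min_dist, dist)
        # Next sample is the point farthest from everything selected so far.
        selected[i] = np.argmax(min_dist)
    return selected


def stage_sizes(n_points: int, stride: int = 4, n_stages: int = 4) -> list[int]:
    """Point counts after each stage, assuming division by `stride` per stage."""
    sizes = []
    for _ in range(n_stages):
        n_points //= stride
        sizes.append(n_points)
    return sizes


# Data preparation (outside the network): reduce a raw cloud to 1024 points.
cloud = np.random.default_rng(0).normal(size=(5000, 3))  # stand-in for one shape
idx = farthest_point_sample(cloud, 1024)
prepared = cloud[idx]  # (1024, 3); transpose to (3, 1024) before batching

# Inside the network: four stages with stride 4.
sizes = stage_sizes(1024)
print(prepared.shape, sizes)
```

The first step fixes the input size so that samples can be batched; the second is the hierarchical downsampling that happens inside the network, so the two are not redundant.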

Regarding the compile issue, there's no need for additional files. Have you installed the package using pip install pointnext? Alternatively, you can build the CUDA ops by running pip install -e . after cloning the repository.

abrahamezzeddine commented 8 months ago

Hello kentechx,

That is correct, that is what I did: pip install pointnext. But it was complaining about some tensor data being deprecated, so I had to replace the tensor data with tensor data ptr or something.

When I then changed it to the expected tensor data ptr, it complained about a wrapper error. I had to create the wrapper file, and then it worked just fine. I am using Windows. In case you need the files for review, I can attach them, of course. :)

kentechx commented 8 months ago

I haven't tested it on Windows yet. And sure, you can make a PR for the files. Thanks!

abrahamezzeddine commented 8 months ago

Ok, will do it soon.

Second question: if I have, for example, several .xyz files that contain XYZ, RGB and normals, do I split up the content and feed it as follows?

XYZ - XYZ coordinates
X - RGB + normals (I guess I do not feed XYZ here too?)

For labeling, would I also be feeding categories into X?

Thank you!

kentechx commented 8 months ago

You should feed xyz into x. Otherwise, the network will not consume xyz as the input feature. Here's an example.

import torch
from pointnext import PointNext, pointnext_s

encoder = pointnext_s(in_dim=6)
model = PointNext(40, encoder=encoder).cuda()

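xyz = torch.randn(2, 3, 1024).cuda()  # coordinates: (batch, 3, num_points)
rgb = torch.rand(2, 3, 1024).cuda()   # colors: (batch, 3, num_points)
x = torch.cat([xyz, rgb], dim=1)      # features: (batch, 6, num_points), matches in_dim=6
out = model(x, xyz)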

Don't feed categories into x when training. Categories are only needed in part segmentation. In part segmentation, input the category as the third parameter, as shown in the code example in the README.
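
For the .xyz case asked about above, a rough sketch of the split could look like the following. The column order (x y z r g b nx ny nz), the 0-255 color range, and the two-line in-memory file are assumptions for illustration only; adjust them to the real files.

```python
import io

import numpy as np

# Stand-in for one .xyz file; assumed columns: x y z r g b nx ny nz.
fake_file = io.StringIO(
    "0.0 0.1 0.2 255 128 0 0.0 0.0 1.0\n"
    "1.0 1.1 1.2 0 64 255 0.0 1.0 0.0\n"
)
data = np.loadtxt(fake_file, dtype=np.float32)  # (N, 9)

xyz = data[:, 0:3]            # coordinates, used for sampling
rgb = data[:, 3:6] / 255.0    # colors scaled to [0, 1] (assumes 0-255 input)
normals = data[:, 6:9]

# x carries every per-point feature, including xyz itself.
x = np.concatenate([xyz, rgb, normals], axis=1)  # (N, 9)

# Channels-first layout with a leading batch axis, as in the example above.
xyz_b = np.ascontiguousarray(xyz.T[None])  # (1, 3, N)
x_b = np.ascontiguousarray(x.T[None])      # (1, 9, N)
print(xyz_b.shape, x_b.shape)
```

The arrays can then be converted with `torch.from_numpy(...)` and moved to the GPU. With XYZ, RGB and normals all in x, the encoder would be created with `in_dim=9` instead of 6.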

gabrielconstantin02 commented 5 months ago

Hey,

First of all, thanks for the great implementation! Hopping on the back of this question: in the above example, I guess 2 is the batch size and 40 the number of classes to predict, so the output would be [2, 40, 4].

It is unclear to me how this can then be used for classification with 40 classes. Does one need to apply a Linear layer to get [2, 40], or am I missing something?

Sorry for the layman's question, I'm new to this area of graph neural networks.

kentechx commented 5 months ago

Oh, that example was a mistake. For classification, PointNext(40, encoder=encoder) only returns the backbone (which downsamples the input point cloud to a lower resolution and extracts features). You will need to write the head yourself. Here's an example:

import torch
import torch.nn as nn
from pointnext import PointNext, pointnext_s

class Model(nn.Module):

    def __init__(self, in_dim=6, out_dim=40, dropout=0.5):
        super().__init__()

        encoder = pointnext_s(in_dim=in_dim)
        self.backbone = PointNext(1024, encoder=encoder)

        self.norm = nn.BatchNorm1d(1024)
        self.act = nn.ReLU()
        self.cls_head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, out_dim),
        )

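    def forward(self, x, xyz):
        # Backbone output: per-point features at the lowest resolution, (batch, 1024, num_points_out)
        out = self.norm(self.backbone(x, xyz))
        # Global average pooling over the remaining points -> (batch, 1024)
        out = out.mean(dim=-1)
        out = self.act(out)
        out = self.cls_head(out)  # class logits, (batch, out_dim)
        return out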

model = Model(in_dim=6, out_dim=40).cuda()
xyz = torch.randn(2, 3, 1024).cuda()
rgb = torch.rand(2, 3, 1024).cuda()
x = torch.cat([xyz, rgb], dim=1)
out = model(x, xyz)
print(out.shape)  # (2, 40)

gabrielconstantin02 commented 5 months ago

I see. I had a similar implementation, but without the normalization and the mean, and the model was constantly predicting the same class while hitting NaN loss.

Now it works like magic and makes much more sense. Thank you so much for taking the time to respond!