valeoai / LightConvPoint


Implement your model for point clouds #4

Closed: shangguan9191 closed this issue 4 years ago

shangguan9191 commented 4 years ago

Dear Author,

I have a 3D point cloud dataset of the form Data(pos=[10240, 3], y=[1]). I am really confused about input_pts.

What would be the best way to use your model with it?

Should I transform my data? Or rewrite the architecture, the forward pass, and the training loop?

aboulch commented 4 years ago

Hello, input_pts are the point coordinates of the input point cloud, i.e. pos in your case. The format is [BatchSize, Dim, NPoints]. The features are inputs_fts or x, depending on the example, i.e. y in your case.
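A minimal sketch of that layout, assuming every cloud has the same number of points (10240 here). A pos tensor like the one in the question is stored as [NPoints, Dim], so each cloud has to be transposed before it matches [BatchSize, Dim, NPoints]. The helper name stack_clouds is hypothetical and not part of LightConvPoint.

```python
import torch


def stack_clouds(clouds):
    """Stack point clouds of shape [NPoints, Dim] into [BatchSize, Dim, NPoints].

    Hypothetical helper: assumes all clouds share the same NPoints.
    """
    # torch.stack gives [BatchSize, NPoints, Dim]; swap the last two axes,
    # then copy into a contiguous tensor (transpose alone returns a strided view).
    return torch.stack(clouds, dim=0).transpose(1, 2).contiguous()


# Two random clouds of 10240 points in 3D.
clouds = [torch.rand(10240, 3), torch.rand(10240, 3)]
input_pts = stack_clouds(clouds)
print(input_pts.shape)  # torch.Size([2, 3, 10240])
```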

shangguan9191 commented 4 years ago

In your model for the classification task, there is def forward(self, x, input_pts):.

I wrote my training function like this:

```python
def train(epoch):
    model.train()

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data.x, data.pos), data.y)
        loss.backward()
        optimizer.step()
```

I simply plugged your classification network into my model, but it did not work. I wonder whether that is because there is no data.x in Data(pos=[10240, 3], y=[1]).

What would you suggest as the proper way to use ConvPoint classification on my own dataset, like the one I mentioned?

aboulch commented 4 years ago

You need to reshape the data so that input_pts.shape = [BatchSize, Dim, NPoints] and x.shape = [BatchSize, C, NPoints]. For example, with batch size 1 and a constant feature x = 1, you can do:

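```python
npoints = data.pos.shape[0]  # data.pos is [NPoints, 3], so NPoints is axis 0
x = torch.ones(1, 1, npoints).to(device)  # constant feature: [1, 1, NPoints]
input_pts = data.pos.transpose(0, 1).unsqueeze(0)  # [1, 3, NPoints]
```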

then:

```python
outputs = model(x, input_pts)
```

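Put together as one training step for a single cloud, the recipe above might look like the following sketch. StandInClassifier is a hypothetical placeholder that only mimics the forward(x, input_pts) signature; it is not the LightConvPoint network. train_step is likewise a hypothetical name. The loss assumes the model returns log-probabilities, as the nll_loss call in the question implies; if the real network returns raw logits, F.cross_entropy would be the appropriate loss instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandInClassifier(nn.Module):
    """Hypothetical placeholder with a forward(x, input_pts) signature.
    NOT the LightConvPoint model: it just averages over the points."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_channels + 3, num_classes)

    def forward(self, x, input_pts):
        # x: [B, C, N], input_pts: [B, 3, N] -> concatenate, average over N.
        pooled = torch.cat([x, input_pts], dim=1).mean(dim=2)
        return F.log_softmax(self.fc(pooled), dim=1)


def train_step(model, optimizer, pos, y, device):
    """One optimization step on a single cloud with pos [NPoints, 3], y [1]."""
    pos = pos.to(device)
    y = y.to(device)
    npoints = pos.shape[0]                         # NPoints is the first axis of pos
    x = torch.ones(1, 1, npoints, device=device)   # constant feature, C = 1
    input_pts = pos.transpose(0, 1).unsqueeze(0)   # [1, 3, NPoints]
    optimizer.zero_grad()
    outputs = model(x, input_pts)                  # [1, num_classes] log-probabilities
    loss = F.nll_loss(outputs, y)
    loss.backward()
    optimizer.step()
    return loss.item()


device = torch.device("cpu")
model = StandInClassifier(in_channels=1, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = train_step(model, optimizer, torch.rand(10240, 3), torch.tensor([4]), device)
print(loss)
```

The reshape happens in the training loop, right before the forward pass, so the model itself keeps the [BatchSize, Dim, NPoints] interface.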
shangguan9191 commented 4 years ago

Doesn't this example (https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_classification.py) already reshape the data? I am really new to this. Could you give more specific instructions on which step should reshape the data: loading it in train_loader, the forward pass, or the training loop?

aboulch commented 4 years ago

Hello, I am not familiar with pytorch_geometric, so I do not know its data format.
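For readers in the same situation: the PyTorch Geometric DataLoader does not produce [BatchSize, Dim, NPoints]. It concatenates the clouds of a batch, so data.pos has shape [BatchSize * NPoints, 3], and data.x is None when the dataset has no per-point features. A minimal sketch of the conversion, assuming every cloud in the batch has the same number of points; pyg_batch_to_convpoint is a hypothetical helper name, not part of either library.

```python
import torch


def pyg_batch_to_convpoint(pos, num_graphs, x=None):
    """Convert a flat PyTorch Geometric-style batch into ConvPoint-style tensors.

    pos: [num_graphs * NPoints, Dim], clouds concatenated along axis 0
    x:   [num_graphs * NPoints, C] per-point features, or None
    Returns (features [B, C, NPoints], input_pts [B, Dim, NPoints]).
    Assumes every cloud in the batch has the same number of points.
    """
    dim = pos.shape[1]
    # Split the concatenated rows back into clouds, then move Dim before NPoints.
    input_pts = pos.reshape(num_graphs, -1, dim).transpose(1, 2)
    npoints = input_pts.shape[2]
    if x is None:
        # No per-point features in the dataset: use a constant feature of 1 (C = 1).
        features = torch.ones(num_graphs, 1, npoints, device=pos.device, dtype=pos.dtype)
    else:
        features = x.reshape(num_graphs, npoints, -1).transpose(1, 2)
    return features, input_pts


# Four clouds of 10240 points, concatenated the way a PyG DataLoader batches them.
pos = torch.rand(4 * 10240, 3)
features, input_pts = pyg_batch_to_convpoint(pos, num_graphs=4)
print(features.shape, input_pts.shape)  # torch.Size([4, 1, 10240]) torch.Size([4, 3, 10240])
```

In the training loop from the question, model(data.x, data.pos) would become features, input_pts = pyg_batch_to_convpoint(data.pos, data.num_graphs, data.x) followed by model(features, input_pts); data.y already has shape [BatchSize], which is what F.nll_loss expects as the target.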