shangguan9191 opened 4 years ago
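Hi, and thanks for this issue. We currently do not have a PointNet example since it is not a relational method: it processes every point independently with a shared MLP and only combines information across points in a final global pooling step. However, it can be easily implemented, similar to this: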
```python
class PointNet(torch.nn.Module):
    def __init__(self, ...):
        ...
        self.mlp1 = ...
        self.mlp2 = ...

    def forward(self, pos, batch):
        x = self.mlp1(pos)
        x = global_max_pool(x, batch)
        return self.mlp2(x)
```
I noticed that in your point cloud benchmark, the forward pass has the same signature. Would it be possible to write PointNet in the same way as this?
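```python
import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import PointConv, fps, radius_graph, global_max_pool


class Net(torch.nn.Module):
    def __init__(self, num_classes):
        super(Net, self).__init__()

        # Each local MLP receives the neighbor's features concatenated with the
        # relative position pos_j - pos_i (3 dims): inputs 3, 64 + 3, 128 + 3.
        nn = Seq(Lin(3, 64), ReLU(), Lin(64, 64))
        self.conv1 = PointConv(local_nn=nn)

        nn = Seq(Lin(67, 128), ReLU(), Lin(128, 128))
        self.conv2 = PointConv(local_nn=nn)

        nn = Seq(Lin(131, 256), ReLU(), Lin(256, 256))
        self.conv3 = PointConv(local_nn=nn)

        self.lin1 = Lin(256, 256)
        self.lin2 = Lin(256, 256)
        self.lin3 = Lin(256, num_classes)

    def forward(self, pos, batch):
        # Level 1: connect points within radius 0.2 (no input features yet),
        # then keep 50% of each cloud via farthest point sampling.
        radius = 0.2
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv1(None, pos, edge_index))

        idx = fps(pos, batch, ratio=0.5)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        # Level 2: larger radius on the subsampled points, then keep 25%.
        radius = 0.4
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv2(x, pos, edge_index))

        idx = fps(pos, batch, ratio=0.25)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        # Level 3: radius 1, then pool each cloud into a single feature vector.
        radius = 1
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv3(x, pos, edge_index))

        x = global_max_pool(x, batch)

        # Classification head.
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
```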
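The main structural difference from PointNet is the neighborhood grouping done by `radius_graph`. To illustrate what that call computes, here is a brute-force O(N²) version in plain PyTorch. The function `radius_graph_bruteforce` is hypothetical and not part of PyG. The library version is far more memory efficient and, by default, also caps each point at `max_num_neighbors=32` neighbors, which this sketch does not.

```python
import torch


def radius_graph_bruteforce(pos, r, batch):
    """Illustrative radius graph for batched point clouds (hypothetical helper).

    Connects every pair of distinct points that belong to the same cloud and
    lie within distance r of each other. Returns edge_index of shape [2, E]
    with row 0 = neighbor (source) and row 1 = center (target).
    """
    n = pos.size(0)
    dist = torch.cdist(pos, pos)                          # [N, N] pairwise distances
    same_cloud = batch.view(-1, 1) == batch.view(1, -1)   # never link across clouds
    idx = torch.arange(n, device=pos.device)
    not_self = idx.view(-1, 1) != idx.view(1, -1)         # no self-loops
    mask = (dist <= r) & same_cloud & not_self
    center, neighbor = mask.nonzero(as_tuple=True)
    return torch.stack([neighbor, center], dim=0)


# Usage: points 0 and 1 are close and in the same cloud; point 3 sits at the
# same location as point 0 but belongs to another cloud, so it gets no edge.
pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0],
                    [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
batch = torch.tensor([0, 0, 0, 1])
edge_index = radius_graph_bruteforce(pos, r=0.2, batch=batch)
```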
Yes, you can simply swap out Net with the above PointNet example.
Could you please consider writing the code for it, if it does not take too much of your time? For people who are interested in comparing the algorithms, it would be easier to reimplement them this way. All I managed to do was reshape the data; the code for the layers and the forward pass is tricky for me.
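On the reshaping step: assume the data is already a dense tensor of shape [B, N, 3] (B clouds with N points each), as in many standalone PointNet implementations. Then converting it into the flat `(pos, batch)` pair that `forward(pos, batch)` expects is a reshape plus a per-point cloud-index vector. The helper name `dense_to_pos_batch` is hypothetical.

```python
import torch


def dense_to_pos_batch(points):
    """Flatten a dense [B, N, 3] tensor into (pos, batch) (hypothetical helper).

    pos:   [B * N, 3], all points of all clouds stacked.
    batch: [B * N],    batch[k] is the index of the cloud that point k belongs to.
    """
    B, N, _ = points.shape
    pos = points.reshape(B * N, 3)  # reshape (not view) also accepts non-contiguous input
    batch = torch.arange(B, device=points.device).repeat_interleave(N)
    return pos, batch


# Usage: 4 clouds of 1024 points each.
points = torch.rand(4, 1024, 3)
pos, batch = dense_to_pos_batch(points)  # pos: [4096, 3], batch: [4096]
```

If your tensors are channel-first ([B, 3, N]), call `points.transpose(1, 2)` first. For clouds of different sizes, PyG's data loader builds the `batch` vector automatically when collating `Data` objects, and `torch_geometric.utils.to_dense_batch` performs the inverse conversion back to a padded dense tensor.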
Dear author,
I am wondering whether you could add a PointNet model to your examples; I noticed that there are only PointNet++ examples in your repo. I have already reimplemented PointNet for my own dataset in another repo, but the total training time is very long. I would therefore really appreciate being able to compare the running time of similar methods within the same repo, to make the results easier to interpret.
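For the running-time comparison, a minimal timing sketch is shown below. It works for any model with a `forward(pos, batch)` that returns log-probabilities, as the benchmark `Net` does. The function name `time_train_step`, the Adam optimizer, its learning rate, and the NLL loss are all assumptions made for illustration.

```python
import time

import torch
import torch.nn.functional as F


def time_train_step(model, pos, batch, y, repeats=10):
    """Average wall-clock seconds per training step (hypothetical helper).

    Runs real optimizer steps, so pass a throwaway copy of the model.
    y: [num_clouds] integer class labels.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()

    def step():
        optimizer.zero_grad()
        loss = F.nll_loss(model(pos, batch), y)
        loss.backward()
        optimizer.step()

    step()  # warm-up: excludes one-time costs such as CUDA initialization
    if pos.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        step()
    if pos.is_cuda:
        # GPU kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats
```

Calling it with the same `(pos, batch, y)` tensors for each model keeps the comparison fair. Without the `torch.cuda.synchronize()` calls, the timer would mostly measure kernel launches rather than the actual GPU work.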