pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Pointnet-classification task implementation #1551

Open · shangguan9191 opened this issue 4 years ago

shangguan9191 commented 4 years ago

Dear author,

I am wondering if you could add a PointNet model to your examples. I noticed that there are only PointNet++ examples in your repo. I have already reimplemented PointNet for my own dataset in another repo; however, the total training time is very long. I would therefore really appreciate being able to compare the running times of similar methods within the same repo, so that I can interpret the results better.

rusty1s commented 4 years ago

Hi, and thanks for this issue. We currently do not have a PointNet example since PointNet is not a relational method (it processes each point independently before pooling). However, it can easily be implemented, similar to this:

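import torch
from torch_geometric.nn import global_max_pool


class PointNet(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        ...
        self.mlp1 = ...  # shared MLP applied to every point independently
        self.mlp2 = ...  # classifier MLP applied to the pooled per-cloud feature

    def forward(self, pos, batch):
        x = self.mlp1(pos)             # per-point features
        x = global_max_pool(x, batch)  # one feature vector per point cloud
        return self.mlp2(x)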
shangguan9191 commented 4 years ago

I noticed that the forward method in your point cloud benchmark has the same signature. Is it possible to write PointNet in this way?

import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import PointConv, fps, global_max_pool, radius_graph


class Net(torch.nn.Module):
    def __init__(self, num_classes):
        super(Net, self).__init__()

        # PointConv feeds local_nn with [x_j, pos_j - pos_i]: 3 channels when x is None,
        # 64 + 3 = 67 after conv1, and 128 + 3 = 131 after conv2.
        nn = Seq(Lin(3, 64), ReLU(), Lin(64, 64))
        self.conv1 = PointConv(local_nn=nn)

        nn = Seq(Lin(67, 128), ReLU(), Lin(128, 128))
        self.conv2 = PointConv(local_nn=nn)

        nn = Seq(Lin(131, 256), ReLU(), Lin(256, 256))
        self.conv3 = PointConv(local_nn=nn)

        self.lin1 = Lin(256, 256)
        self.lin2 = Lin(256, 256)
        self.lin3 = Lin(256, num_classes)

    def forward(self, pos, batch):
        # Connect points within radius 0.2 and apply the first local PointNet.
        radius = 0.2
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv1(None, pos, edge_index))

        # Keep half of the points via farthest point sampling.
        idx = fps(pos, batch, ratio=0.5)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 0.4
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv2(x, pos, edge_index))

        idx = fps(pos, batch, ratio=0.25)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 1
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv3(x, pos, edge_index))

        # One feature vector per point cloud.
        x = global_max_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
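
Each PointConv layer in this Net acts as a small PointNet applied to every radius neighborhood, which is where the 3, 67 and 131 input widths come from; PointNet itself corresponds to a single such step whose neighborhood is the whole cloud, i.e. the earlier PointNet sketch without `radius_graph` and `fps`. Below is a rough, hypothetical plain-PyTorch re-implementation of that layer (`point_conv_sketch` is not a PyG function). It assumes the layer max-aggregates `local_nn([x_j, pos_j - pos_i])` over each neighborhood plus a self-loop, with no `global_nn`, and that `edge_index` has no self-loops (as `radius_graph` produces by default). It is meant to illustrate the idea, not to reproduce PyG's actual code.

```python
import torch
from torch.nn import Linear, ReLU, Sequential


def point_conv_sketch(local_nn, x, pos, edge_index):
    # Hypothetical sketch: message for edge j -> i is local_nn([x_j, pos_j - pos_i]),
    # max-aggregated at i. A self-loop i -> i is added for every point.
    num_points = pos.size(0)
    loops = torch.arange(num_points).repeat(2, 1)     # [2, num_points] self-loops
    src, dst = torch.cat([edge_index, loops], dim=1)  # src = j (neighbor), dst = i (center)
    rel = pos[src] - pos[dst]                         # relative positions: 3 channels
    msg = rel if x is None else torch.cat([x[src], rel], dim=-1)
    msg = local_nn(msg)
    out = msg.new_zeros(num_points, msg.size(-1))
    index = dst.unsqueeze(-1).expand_as(msg)
    # include_self=False: each row is the max over its incoming messages only.
    return out.scatter_reduce(0, index, msg, reduce="amax", include_self=False)


# Usage on 6 random points with hypothetical neighbor pairs (points 4 and 5 are isolated).
torch.manual_seed(0)
pos = torch.randn(6, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
nn1 = Sequential(Linear(3, 64), ReLU(), Linear(64, 64))
nn2 = Sequential(Linear(64 + 3, 128), ReLU(), Linear(128, 128))  # 67 input channels
h1 = point_conv_sketch(nn1, None, pos, edge_index)
h2 = point_conv_sketch(nn2, h1, pos, edge_index)
```
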
rusty1s commented 4 years ago

Yes, you can simply replace Net with the PointNet example above.
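
One detail when swapping: the PointNet sketch returns raw class scores, while Net ends with `F.log_softmax`. Assuming the training loop pairs Net's output with `F.nll_loss` (the loop itself is not shown in this thread), either append `log_softmax` to PointNet's forward or switch the loss to `F.cross_entropy`. The small check below, on synthetic scores and labels, illustrates that the two options give the same loss and the same predictions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)     # stand-in for raw scores from PointNet's last layer
y = torch.tensor([3, 0, 9, 3])  # hypothetical class labels for 4 clouds

# Option A: add log_softmax at the end of forward (as Net does) and keep F.nll_loss.
log_probs = F.log_softmax(logits, dim=-1)
loss_a = F.nll_loss(log_probs, y)

# Option B: return raw scores and use F.cross_entropy, which applies log_softmax internally.
loss_b = F.cross_entropy(logits, y)
```
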

shangguan9191 commented 4 years ago

Could you please consider writing the code for it, if it does not take too much of your time? For people interested in comparing the algorithms, it would be easier to reimplement them this way. All I could do is reshape the data; the code for the layers and the forward pass is tricky for me.
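
On the data side, the `forward(pos, batch)` interface expects all points of a mini-batch stacked into one `[num_points, 3]` tensor plus a `batch` vector holding each point's cloud index. A minimal sketch for converting a dense batch of equally sized clouds, assuming the other repo yields `[B, N, 3]` tensors, or the channels-first `[B, 3, N]` layout that some PointNet code bases use (transposed first). `dense_to_pyg` is a hypothetical helper name.

```python
import torch


def dense_to_pyg(points):
    # Hypothetical helper: flatten a dense [B, N, 3] batch into PyG-style (pos, batch).
    num_clouds, num_points, _ = points.shape
    pos = points.reshape(num_clouds * num_points, 3)                 # all points stacked row-wise
    batch = torch.arange(num_clouds).repeat_interleave(num_points)  # cloud index of every point
    return pos, batch


# Synthetic channels-first input: 4 clouds of 1024 points each.
points_cf = torch.randn(4, 3, 1024)
pos, batch = dense_to_pyg(points_cf.transpose(1, 2))  # [B, 3, N] -> [B, N, 3] first
```

The resulting `pos` and `batch` can be passed directly to `forward(pos, batch)`. For the reverse direction, `torch_geometric.utils.to_dense_batch(x, batch)` returns a padded dense tensor together with a boolean mask.
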