Open sayakgis opened 4 years ago
We do not have a PointCNN segmentation example, but it should be straightforward to implement one yourself by combining the PointCNN
example with the segmentation code.
Where would `num_features` be applied in the PointCNN example?
`num_features` corresponds to the input feature dimension of the points. As such, you can do:
self.conv1 = XConv(
    dataset.num_features,  # input feature dimension of each point
    48,
    dim=3,
    kernel_size=8,
    hidden_channels=32,
)
Thanks for the pointer @rusty1s. The following still yields an error in validation.
class Net(torch.nn.Module):
    """PointCNN-style network that outputs one prediction per point cloud.

    Four ``XConv`` layers extract point features. Two farthest point
    sampling (``fps``) steps coarsen each cloud in between. Then
    ``global_mean_pool`` averages every cloud in the batch into a single
    vector before the MLP head. The result has one row per graph, so this
    architecture fits graph-level classification. It does not fit per-point
    segmentation targets, which need one row per point.

    Args:
        num_classes: Number of output classes.
        num_features: Feature dimension of each input point (``data.x``).
    """

    def __init__(self, num_classes, num_features=3):
        super().__init__()
        self.conv1 = XConv(num_features, 48, dim=3, kernel_size=8,
                           hidden_channels=32)
        self.conv2 = XConv(48, 96, dim=3, kernel_size=12,
                           hidden_channels=64, dilation=2)
        self.conv3 = XConv(96, 192, dim=3, kernel_size=16,
                           hidden_channels=128, dilation=2)
        self.conv4 = XConv(192, 384, dim=3, kernel_size=16,
                           hidden_channels=256, dilation=2)
        # ``Lin`` was never imported in this snippet; use torch.nn.Linear
        # directly so the class is self-contained.
        self.lin1 = torch.nn.Linear(384, 256)
        self.lin2 = torch.nn.Linear(256, 128)
        self.lin3 = torch.nn.Linear(128, num_classes)

    def forward(self, data):
        """Return log-probabilities of shape ``[num_graphs, num_classes]``.

        ``data`` must provide ``x`` (point features), ``pos`` (point
        coordinates) and ``batch`` (graph index of every point).
        """
        x, pos, batch = data.x, data.pos, data.batch

        x = F.relu(self.conv1(x, pos, batch))

        # Farthest point sampling keeps a 0.375 fraction of each cloud.
        # x, pos and batch are indexed together so they stay aligned.
        idx = fps(pos, batch, ratio=0.375)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.relu(self.conv2(x, pos, batch))

        idx = fps(pos, batch, ratio=0.334)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.relu(self.conv3(x, pos, batch))
        x = F.relu(self.conv4(x, pos, batch))

        # Collapse the remaining points of each graph into one vector.
        # From here on there is one row per graph, not per point.
        x = global_mean_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
Error output...
Exception has occurred: ValueError
Expected input batch_size (2) to match target batch_size (4096).
It fails right here:
def validation_step(self, data, batch_idx):
logits = self.forward(data) # Forward pass.
loss = self.loss(logits, data.y, ignore_index=0) # Loss com
...
Any thoughts? Thanks in advance for the support and for the excellent library!
How big are your point clouds? There may be a problem where we cannot select `kernel_size`
neighbors in each example after coarsening. The following works for me:
import torch
import torch.nn.functional as F
from torch_cluster import fps
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import XConv, global_mean_pool
class Net(torch.nn.Module):
    """PointCNN-style classifier producing one prediction per point cloud.

    The network alternates ``XConv`` feature extraction with farthest point
    sampling (``fps``) to coarsen each cloud. It then averages each cloud
    into one vector with ``global_mean_pool`` and applies a three-layer MLP
    head. The output therefore has one row per graph in the batch.

    Args:
        num_classes: Number of output classes.
        num_features: Feature dimension of each input point (``data.x``).
    """

    def __init__(self, num_classes, num_features):
        super().__init__()
        self.conv1 = XConv(num_features, 48, dim=3, kernel_size=8,
                           hidden_channels=32)
        self.conv2 = XConv(48, 96, dim=3, kernel_size=12, hidden_channels=64,
                           dilation=2)
        self.conv3 = XConv(96, 192, dim=3, kernel_size=16, hidden_channels=128,
                           dilation=2)
        self.conv4 = XConv(192, 384, dim=3, kernel_size=16,
                           hidden_channels=256, dilation=2)
        self.lin1 = torch.nn.Linear(384, 256)
        self.lin2 = torch.nn.Linear(256, 128)
        self.lin3 = torch.nn.Linear(128, num_classes)

    def forward(self, data):
        """Return log-probabilities of shape ``[num_graphs, num_classes]``."""
        x, pos, batch = data.x, data.pos, data.batch

        x = F.relu(self.conv1(x, pos, batch))

        # Coarsen each cloud to a 0.375 fraction of its points. Index x,
        # pos and batch with the same idx so they remain aligned.
        idx = fps(pos, batch, ratio=0.375)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.relu(self.conv2(x, pos, batch))

        idx = fps(pos, batch, ratio=0.334)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.relu(self.conv3(x, pos, batch))
        x = F.relu(self.conv4(x, pos, batch))

        # One pooled feature vector per graph (graph-level task).
        x = global_mean_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
# Smoke test: batch three copies of one random cloud (1000 points with
# 16 features each) and inspect the classifier's output shape.
data = Data(
    pos=torch.randn(1000, 3),
    x=torch.randn(1000, 16),
)
data_list = [data] * 3
loader = DataLoader(data_list, batch_size=3)
batch = next(iter(loader))

model = Net(num_classes=10, num_features=16)
out = model(batch)
print(out.shape)  # one row per cloud: torch.Size([3, 10])
Hmm... it's pretty small for an ALS dataset, maybe around 6.5 million points. The point counts definitely change between the start and the end of the forward pass.
After the first line, `x, pos, batch = data.x, data.pos, data.batch`,
the tensor shapes look like this:
x.shape = torch.Size([32767, 10])
pos.shape = torch.Size([32767, 3])
batch.shape = torch.Size([32767])
After `x = self.lin3(x)`,
the tensor shapes look like this:
x.shape = torch.Size([8, 3])
pos.shape = torch.Size([4112, 3])
batch.shape = torch.Size([4112])
In this case, the error `ValueError: Expected input batch_size (8) to match target batch_size (32767).`
is raised when we reach `loss = self.loss(logits, data.y, ignore_index=0)`,
because `data.y` still has the shape of the original tensor, `torch.Size([32767])`.
OK, but your model performs a global pooling step while you are training against node-level targets. Consider dropping the `fps`
calls and the `global_mean_pool`
operator; the current architecture is designed for graph-level classification/regression tasks. Without them, the network outputs one row per input point, which matches the shape of `data.y`.
Perfect! This seems to get me to what I'm looking for. Thanks for the assist.
import torch
import torch.nn.functional as F
from torch_geometric.nn import XConv
class Net(torch.nn.Module):
    """PointCNN-style network for per-point (node-level) segmentation.

    The model neither downsamples the cloud (no ``fps``) nor pools it (no
    ``global_mean_pool``). Every layer therefore keeps one row per input
    point, and the output lines up with node-level targets such as
    ``data.y``.

    Args:
        num_classes: Number of per-point classes.
        num_features: Feature dimension of each input point (``data.x``).
    """

    def __init__(self, num_classes, num_features=3):
        super().__init__()
        self.conv1 = XConv(num_features, 48, dim=3, kernel_size=8,
                           hidden_channels=32)
        # NOTE(review): with no downsampling, conv2-conv4 operate on the
        # full-resolution cloud with dilated neighbourhoods. This may be
        # memory-heavy for large batches; confirm on real inputs.
        self.conv2 = XConv(48, 96, dim=3, kernel_size=12,
                           hidden_channels=64, dilation=2)
        self.conv3 = XConv(96, 192, dim=3, kernel_size=16,
                           hidden_channels=128, dilation=2)
        self.conv4 = XConv(192, 384, dim=3, kernel_size=16,
                           hidden_channels=256, dilation=2)
        # Point-wise MLP head: applied independently to each point's features.
        self.lin1 = torch.nn.Linear(384, 256)
        self.lin2 = torch.nn.Linear(256, 128)
        self.lin3 = torch.nn.Linear(128, num_classes)

    def forward(self, data):
        """Return log-probabilities of shape ``[num_points, num_classes]``."""
        x, pos, batch = data.x, data.pos, data.batch

        x = F.relu(self.conv1(x, pos, batch))
        x = F.relu(self.conv2(x, pos, batch))
        x = F.relu(self.conv3(x, pos, batch))
        x = F.relu(self.conv4(x, pos, batch))

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
❓ PointCNN segmentation example
Hello Matthias,
Thanks for putting so much effort into PyG; it is becoming more fabulous every day.
I have been trying PointCNN on segmentation tasks. Could you please point me to any example code for this?
I am currently using the original TensorFlow repository but would like to migrate to PyTorch.
Thanks for the help in advance.
Sayak