WangYueFt / dgcnn


Understanding equation #8 in research paper (Section 3.1) #74

Open akashaero opened 3 years ago

akashaero commented 3 years ago

So I am going through your paper, and in Section 3.1 (Edge Convolution), equation (8) involves learnable weights theta and phi.
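
For reference, this is how I read equation (8) (quoting from memory, so please correct me if the indexing is off):

$$e'_{ijm} = \operatorname{ReLU}\big(\theta_m \cdot (x_j - x_i) + \phi_m \cdot x_i\big), \qquad \Theta = (\theta_1, \ldots, \theta_M, \phi_1, \ldots, \phi_M)$$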

Looking at the graph feature function, it does not use any weights when constructing the graph (feature = concat(xj - xi, xi)). Is that right, or do the subsequent convolutional layers in the partseg/semseg models supply the learnable weights for this graph feature?

By the way, I am using this PyTorch implementation of your network:

# Graph feature construction (uses the knn() helper defined alongside it in the repo)
import torch

def get_graph_feature(x, k=20, idx=None, dim9=False):
    '''
    Builds the edge feature "concat(xj - xi, xi)" over each point's k nearest neighbours.
    Incoming (with 3 features per point)
        x --> torch tensor ([batch_size, 3, num_points])
    Outgoing
        feature --> torch tensor ([batch_size, 2*num_dims, num_points, k])
    '''
    batch_size = x.size(0)
    num_points = x.size(2)

    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if not dim9:
            idx = knn(x, k=k)          # (batch_size, num_points, k)
        else:
            idx = knn(x[:, 6:], k=k)   # semseg: use only the last 3 channels for the neighbour search
    device = x.device   # build the offset indices on the same device as x

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points

    idx = idx + idx_base   # idx -> (batch_size, num_points, k) , idx_base -> (batch_size,1,1)

    idx = idx.view(-1)   # (batch_size, num_points, k) --> (batch_size*num_points*k)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()   # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
    feature = x.view(batch_size*num_points, -1)[idx, :]  # (batch_size*num_points, num_dims) -> (batch_size*num_points*k, num_dims)
    feature = feature.view(batch_size, num_points, k, num_dims) # (batch_size*num_points*k, num_dims) -> (batch_size, num_points, k, num_dims)

    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    ######## NO LEARNABLE WEIGHTS HERE ########
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    ######## NO LEARNABLE WEIGHTS HERE ########

    return feature      # (batch_size, 2*num_dims, num_points, k)
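
# Toy usage sketch (my own, not from the repo) to make the shapes concrete;
# it assumes the knn() helper from the repo is in scope.
x_toy = torch.rand(2, 3, 1024)              # (batch_size, 3, num_points), xyz only
feat_toy = get_graph_feature(x_toy, k=20)   # -> (2, 6, 1024, 20) == (batch_size, 2*num_dims, num_points, k)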

# part seg network (first EdgeConv block)
x = get_graph_feature(x, k=self.k)      # (batch_size, 2*num_dims, num_points, k)
x = self.conv1(x)                       # shared MLP applied to every (point, neighbour) edge
x = self.conv2(x)
x1 = x.max(dim=-1, keepdim=False)[0]    # max-pool over the k neighbours
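
My current guess (please correct me if this is wrong): the theta and phi of equation (8) are the weights of the 1x1 Conv2d inside self.conv1 (and the later conv layers), applied across the 2*num_dims channels that get_graph_feature outputs. The first num_dims channels carry (xj - xi) and are multiplied by theta, the last num_dims carry xi and are multiplied by phi, and the max over k is the aggregation. Below is a minimal, self-contained sketch of that reading; the layer sizes, the BatchNorm/LeakyReLU details, and all variable names are my assumptions, not necessarily the repo's exact definitions.

# sketch: where the learnable weights of equation (8) would live
import torch
import torch.nn as nn

edge_feat = torch.rand(2, 6, 1024, 20)   # stand-in for get_graph_feature output: (batch, 2*num_dims, points, k)

conv1 = nn.Sequential(
    nn.Conv2d(6, 64, kernel_size=1, bias=False),   # weight shape (64, 6, 1, 1): 64 output channels = M
    nn.BatchNorm2d(64),
    nn.LeakyReLU(negative_slope=0.2),
)

out = conv1(edge_feat)        # (2, 64, 1024, 20): one response per edge and per output channel m
x1 = out.max(dim=-1)[0]       # max over the k neighbours -> (2, 64, 1024)

# input channels 0..2 hold (xj - xi) and channels 3..5 hold xi, so per output channel m:
theta = conv1[0].weight[:, :3, 0, 0]   # (64, 3), acts on (xj - xi)
phi   = conv1[0].weight[:, 3:, 0, 0]   # (64, 3), acts on xi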