qinglew / PCN-PyTorch

Implementation of PCN(Point Completion Network) in PyTorch.

torch.max fn - dimension #3

Closed aleenajohnsony closed 3 years ago

aleenajohnsony commented 3 years ago

This is from the encoder. Why is the dimension 2? I could not find a valid dimension.

```python
def forward(self, x):
    n = x.size()[2]

    # first shared mlp
    x = F.relu(self.bn1(self.conv1(x)))           # (B, 128, N)
    f = self.bn2(self.conv2(x))                   # (B, 256, N)

    # point-wise maxpool
    g = torch.max(f, dim=2, keepdim=True)[0]      # (B, 256, 1)

    # expand and concat
    x = torch.cat([g.repeat(1, 1, n), f], dim=1)  # (B, 512, N)

    # second shared mlp
    x = F.relu(self.bn3(self.conv3(x)))           # (B, 512, N)
    x = self.bn4(self.conv4(x))                   # (B, 1024, N)

    # point-wise maxpool
    v = torch.max(x, dim=-1)[0]                   # (B, 1024)

    return v
```
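To see why reducing over dimension 2 is valid here: `Conv1d` outputs are laid out as (B, channels, N), so dim 2 indexes the N points. Below is a minimal shape-only sketch of the two pooling steps, using NumPy as a stand-in for torch (the batch size and point count are example values, not from the repo):

```python
import numpy as np

B, N = 4, 2048                   # example batch size and points per cloud
f = np.random.rand(B, 256, N)    # stand-in for the conv2 output (B, 256, N)

# point-wise max over the N points (axis 2), keeping the reduced axis
g = f.max(axis=2, keepdims=True)                          # (B, 256, 1)

# broadcast the global feature back to every point, concat channel-wise
x = np.concatenate([np.tile(g, (1, 1, N)), f], axis=1)    # (B, 512, N)

assert g.shape == (B, 256, 1)
assert x.shape == (B, 512, N)
```

One difference worth noting: `torch.max(f, dim=2)` returns a `(values, indices)` tuple, which is why the original code indexes `[0]`; NumPy's `max` returns the values directly.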
qinglew commented 3 years ago

The tensors here have shape (B, C, N): batch, channels, points. So dim=2 (equivalently dim=-1) is the point dimension, and the max is taken over all N points.

Note that the first max pooling uses `torch.max(f, dim=2, keepdim=True)[0]`; the keyword argument `keepdim` is set to `True`, which means the reduced dimension is kept, so the output of the first max pooling has shape (B, 256, 1). In the second max pooling, I didn't set `keepdim` to `True`, because I just want every point cloud to become a latent code, i.e. a vector. So, in a mini-batch, every point cloud is represented as a vector, and the final shape for a batch of point clouds is (B, 1024).
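The `keepdim` difference described above can be checked directly. A small sketch with NumPy's equivalent `keepdims` flag (`torch.max` produces the same shapes on its values output; the sizes below are example values):

```python
import numpy as np

B, C, N = 4, 1024, 2048
x = np.random.rand(B, C, N)

pooled_keep = x.max(axis=2, keepdims=True)   # (B, 1024, 1): reduced axis kept as size 1
pooled_drop = x.max(axis=-1)                 # (B, 1024):    axis removed, one vector per cloud

assert pooled_keep.shape == (B, C, 1)
assert pooled_drop.shape == (B, C)
```

The kept axis of size 1 is what lets `g.repeat(1, 1, n)` broadcast the global feature back alongside the per-point features; dropping it at the end yields the final latent vector per cloud.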