lucidrains / point-transformer-pytorch

Implementation of the Point Transformer layer, in PyTorch
MIT License

Issues with my wrapper code #11

Closed: StellaAthena closed this issue 3 years ago

StellaAthena commented 3 years ago

I wrote some wrapper code to turn this layer into a full transformer, but I can't figure out what is going wrong. The following works:

```python
import torch
from point_transformer_pytorch import PointTransformerLayer

layer = PointTransformerLayer(
    dim = 7,                    # matches the feature width (last dim of feats)
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16          # each point attends only to its 16 nearest neighbors
)

feats = torch.randn(1, 5, 7)    # (batch, points, feature dim)
pos = torch.randn(1, 5, 3)      # (batch, points, xyz)
mask = torch.ones(1, 5).bool()  # (batch, points)

y = layer(feats, pos, mask = mask)
```

However, this doesn't work:

```python
import torch
from torch import nn
from point_transformer_pytorch import PointTransformerLayer

# feats, pos and mask are the tensors defined in the snippet above

class PointTransformer(nn.Module):
    def __init__(self, feats, mask, neighbors = 16, layers=5, dimension=5):

        super().__init__()

        self.feats = feats
        self.mask = mask
        self.neighbors = neighbors

        self.layers = []

        for _ in range(layers):
            self.layers.append(PointTransformerLayer(
                dim = dimension,
                pos_mlp_hidden_dim = 64,
                attn_mlp_hidden_mult = 4,
                num_neighbors = self.neighbors
            ))

    def forward(self, pos):
        curr_pos = pos
        for layer in self.layers:
            print(curr_pos)
            curr_pos = layer(self.feats, pos, self.mask)
            print("----")
        return curr_pos

model = PointTransformer(feats, mask)
model(pos)
```

The error I'm getting is `mat1 and mat2 shapes cannot be multiplied (5x7 and 5x15)`.
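One way to read the shapes in that message: assuming `PointTransformerLayer` starts with a linear projection from `dim` to `3 * dim` channels for queries, keys and values, the wrapper's default `dimension=5` gives a 5-to-15 projection, while `feats` has 7 channels. The `5x15` in the message is consistent with that. The sketch below reproduces the same kind of failure with a plain `nn.Linear`; `to_qkv_bad` and `to_qkv_ok` are hypothetical stand-ins, not the library's internals.

```python
import torch
from torch import nn

feats = torch.randn(1, 5, 7)  # 5 points with 7 feature channels, as in the issue

# Hypothetical stand-in for a q/k/v projection built with dimension=5:
# nn.Linear(5, 3 * 5) stores a 15x5 weight and expects 5 input channels
to_qkv_bad = nn.Linear(5, 15, bias=False)

try:
    to_qkv_bad(feats)
    error = None
except RuntimeError as e:
    # the inner sizes disagree: 7 input channels vs. a weight expecting 5
    error = str(e)
print(error)

# Sizing the projection to the actual feature width (7 -> 3 * 7) works
to_qkv_ok = nn.Linear(7, 21, bias=False)
qkv = to_qkv_ok(feats)
print(qkv.shape)  # torch.Size([1, 5, 21])
```

In other words, the layer width has to equal the last dimension of the features passed in, which is why the standalone snippet with `dim = 7` runs.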

StellaAthena commented 3 years ago

Never mind, I figured it out.
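The thread doesn't record what the fix was. A minimal corrected wrapper, assuming the intent was to stack several layers over the same point cloud, would: hold the layers in an `nn.ModuleList` so their parameters are registered with the model; take `feats`, `pos` and `mask` as `forward` arguments and feed each layer's output into the next (the original always passes the same `self.feats`); and size every layer to the input's feature width (7 here) instead of a default of 5. `make_layer` and `StandInLayer` are hypothetical names; the stand-in exists only so the sketch runs without `point_transformer_pytorch` installed.

```python
import torch
from torch import nn


class PointTransformer(nn.Module):
    # Sketch of a stacked wrapper. `make_layer` is a hypothetical factory that
    # builds one layer called as layer(feats, pos, mask) and returning features
    # of the same width, e.g. (with the library installed):
    #   lambda: PointTransformerLayer(dim = 7, pos_mlp_hidden_dim = 64,
    #                                 attn_mlp_hidden_mult = 4, num_neighbors = 16)
    def __init__(self, make_layer, depth=5):
        super().__init__()
        # nn.ModuleList (not a plain Python list) registers every layer's
        # parameters, so model.parameters(), .to(device) and state_dict() see them
        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, feats, pos, mask=None):
        # Chain the layers: each one sees the previous layer's output features,
        # while the point positions and mask stay fixed
        x = feats
        for layer in self.layers:
            x = layer(x, pos, mask)
        return x


class StandInLayer(nn.Module):
    # Hypothetical placeholder with the same call signature, used only so this
    # sketch runs without point_transformer_pytorch; it ignores pos and mask
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, pos, mask=None):
        return self.proj(x)


feats = torch.randn(1, 5, 7)    # (batch, points, feature dim)
pos = torch.randn(1, 5, 3)      # (batch, points, xyz)
mask = torch.ones(1, 5).bool()  # (batch, points)

dim = feats.shape[-1]  # size every layer to the feature width (7), not 5
model = PointTransformer(lambda: StandInLayer(dim), depth=5)
out = model(feats, pos, mask)
print(out.shape)  # torch.Size([1, 5, 7])
```

A full transformer block would usually also wrap each layer with a residual connection and normalization; those are left out here to stay close to the original wrapper.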