StellaAthena closed this issue 3 years ago.
I wrote some wrapper code to turn this layer into a full transformer and I can't seem to figure out what is going wrong. The following works:
```python
import torch
from torch import nn, einsum
import x_transformers
from point_transformer_pytorch import PointTransformerLayer

layer = PointTransformerLayer(
    dim = 7,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16  # only the 16 nearest neighbors would be attended to for each point
)

feats = torch.randn(1, 5, 7)
pos = torch.randn(1, 5, 3)
mask = torch.ones(1, 5).bool()

y = layer(feats, pos, mask = mask)
```
However, this doesn't work:
```python
import torch
from torch import nn, einsum
import x_transformers
from point_transformer_pytorch import PointTransformerLayer

class PointTransformer(nn.Module):
    def __init__(self, feats, mask, neighbors = 16, layers=5, dimension=5):
        super().__init__()
        self.feats = feats
        self.mask = mask
        self.neighbors = neighbors
        self.layers = []
        for _ in range(layers):
            self.layers.append(PointTransformerLayer(
                dim = dimension,
                pos_mlp_hidden_dim = 64,
                attn_mlp_hidden_mult = 4,
                num_neighbors = self.neighbors
            ))

    def forward(self, pos):
        curr_pos = pos
        for layer in self.layers:
            print(curr_pos)
            curr_pos = layer(self.feats, pos, self.mask)
            print("----")
        return curr_pos

model = PointTransformer(feats, mask)
model(pos)
```
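Separately from the shape error, the wrapper above keeps its layers in a plain Python list. `nn.Module` only registers submodules that are assigned as attributes directly or held in containers such as `nn.ModuleList`, so layers stored in a list don't show up in `model.parameters()` (an optimizer would never update them) and aren't moved by `model.to(device)`. A minimal comparison, using plain `nn.Linear` layers as stand-ins for `PointTransformerLayer`:

```python
from torch import nn


class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # a plain Python list: nn.Module does not register these layers
        self.layers = [nn.Linear(4, 4) for _ in range(3)]


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


def count_params(module):
    # total number of scalar parameters visible through module.parameters()
    return sum(p.numel() for p in module.parameters())


plain = count_params(PlainList())
registered = count_params(WithModuleList())
print(plain, registered)  # 0 60
```

Each `nn.Linear(4, 4)` has 16 weights and 4 biases, so the registered version exposes 3 × 20 = 60 parameters while the list version exposes none.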
The error I'm getting is:

```
mat1 and mat2 shapes cannot be multiplied (5x7 and 5x15)
```
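The shapes in the message line up with the wrapper's defaults: mat1 is the feature tensor flattened to 5 points × 7 features, and mat2 is the transposed weight of a 5-to-15 linear map. 15 = 3 × 5 is consistent with a layer built with `dimension=5` projecting features to queries, keys and values, while `feats` has 7 features. The sketch below reproduces the same error with a bare `nn.Linear` standing in for that projection (an assumption about where the multiply happens inside `PointTransformerLayer`):

```python
import torch
from torch import nn

feats = torch.randn(1, 5, 7)  # 1 point cloud, 5 points, 7 features per point

# stand-in for a query/key/value projection built with dim = 5 (assumption)
proj_dim5 = nn.Linear(5, 3 * 5, bias=False)
try:
    proj_dim5(feats)
    message = None
except RuntimeError as err:
    # same "mat1 and mat2 shapes cannot be multiplied" error as in the issue
    message = str(err)
print(message)

# sizing the projection from the data itself avoids the mismatch
dim = feats.shape[-1]
proj_dim7 = nn.Linear(dim, 3 * dim, bias=False)
ok_shape = proj_dim7(feats).shape
print(ok_shape)  # torch.Size([1, 5, 21])
```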
NVM I figured it out