lucidrains / En-transformer

Implementation of E(n)-Transformer, which incorporates attention mechanisms into Welling's E(n)-Equivariant Graph Neural Network
MIT License
208 stars 28 forks

Incorrect output shape #11

Closed: leffff closed this issue 1 year ago

leffff commented 1 year ago
# taken from https://github.com/lucidrains/En-transformer/blob/main/README.md
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

https://github.com/lucidrains/En-transformer/blob/b09545f01637056ffb69f1c56a3882e102c0093d/README.md?plain=1#L38

As we can see, the input shapes of the node features (feats) and coordinates (coors) are [batch_size, num_nodes, ...], with num_nodes equal to 1024. However, the output num_nodes in the README comment is 16.

My assumption is that the output shape comment was mistakenly copied from https://github.com/lucidrains/egnn-pytorch/blob/main/README.md

The varying num_nodes is confusing. Can you explain whether this is just a mistake, or whether num_nodes really changes, and if so, why?

lucidrains commented 1 year ago

@leffff hey Lev, sorry the readme had incorrect comments. it should be 1024 like you said

lucidrains commented 1 year ago
import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,                       # depth (number of layers)
    dim_head = 64,                   # dimension per head
    heads = 8,                       # number of heads
    edge_dim = 4,                    # dimension of edge feature
    neighbors = 64,                  # only do attention between each coordinate and its N nearest neighbors - set to 0 to turn off
    talking_heads = True,            # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
    checkpoint = True,               # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
    use_cross_product = True,        # use cross product vectors (idea by @MattMcPartlon)
    num_global_linear_attn_heads = 2 # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.ones(1, 1024).bool()

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)
assert feats.shape == (1, 1024, 512)
assert coors.shape == (1, 1024, 3)
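
As a side note, smaller graphs can be zero-padded to a common length and handled with the mask argument shown above; the output num_nodes always matches the (padded) input num_nodes. Below is a minimal sketch, assuming the same EnTransformer constructor as above and a hypothetical graph with 700 real nodes padded to 1024:

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,
    dim_head = 64,
    heads = 8,
    edge_dim = 4,
    neighbors = 64
)

# a hypothetical graph with 700 real nodes, zero-padded up to 1024
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.zeros(1, 1024).bool()
mask[:, :700] = True  # mark only the first 700 nodes as real

feats_out, coors_out = model(feats, coors, edges, mask = mask)

# num_nodes is preserved - the outputs keep the padded length
assert feats_out.shape == (1, 1024, 512)
assert coors_out.shape == (1, 1024, 3)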

leffff commented 1 year ago

Thanks! Was just checking)

Thanks for your awesome work! It helped me a lot!

lucidrains commented 1 year ago

No problem! What are you training on?

leffff commented 1 year ago

I am conducting research on crystal structure generation. I use crystal structures from AFlow.

lucidrains commented 1 year ago

@leffff very cool! 😍