Closed (leffff closed this issue 1 year ago)
@leffff hey Lev, sorry, the README had incorrect comments. It should be 1024, like you said.
```python
import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,                         # depth
    dim_head = 64,                     # dimension per head
    heads = 8,                         # number of heads
    edge_dim = 4,                      # dimension of edge feature
    neighbors = 64,                    # only do attention between coordinates N nearest neighbors - set to 0 to turn off
    talking_heads = True,              # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
    checkpoint = True,                 # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
    use_cross_product = True,          # use cross product vectors (idea by @MattMcPartlon)
    num_global_linear_attn_heads = 2   # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)
mask = torch.ones(1, 1024).bool()

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)

assert feats.shape == (1, 1024, 512)
assert coors.shape == (1, 1024, 3)
```
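For anyone who wants to reason about the shapes without installing the library, here is a minimal sketch of the bookkeeping in plain Python. It assumes the model keeps the batch and node dimensions unchanged, which is what the corrected comments above state; `expected_output_shapes` is a hypothetical helper written for illustration and is not part of `en_transformer`.

```python
def expected_output_shapes(feats_shape, coors_shape):
    """Return the (feats, coors) output shapes, assuming num_nodes is preserved.

    Hypothetical helper for illustration only; it does not call en_transformer.
    """
    batch, num_nodes, dim = feats_shape
    c_batch, c_nodes, c_dim = coors_shape

    # Features and coordinates must describe the same set of nodes.
    if (batch, num_nodes) != (c_batch, c_nodes):
        raise ValueError(
            f"feats {feats_shape} and coors {coors_shape} disagree on batch or num_nodes"
        )

    # Assumption: the output shapes mirror the input shapes exactly.
    return (batch, num_nodes, dim), (batch, num_nodes, c_dim)


feats_out, coors_out = expected_output_shapes((1, 1024, 512), (1, 1024, 3))
print(feats_out, coors_out)
```

Under that assumption, an input with 1024 nodes yields an output with 1024 nodes, so an output of 16 nodes would not match these inputs.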
Thanks! Was just checking)
Thanks for your awesome work! It helped me a lot!
No problem! What are you training on?
I am conducting research on crystal structure generation. I use crystal structures from AFlow.
@leffff very cool! 😍
https://github.com/lucidrains/En-transformer/blob/b09545f01637056ffb69f1c56a3882e102c0093d/README.md?plain=1#L38
As we can see, the input shape of the node features (feats) and coordinates (coors) is [batch_size, num_nodes, num_features], with num_nodes equal to 1024. However, the output num_nodes is given as 16.
My assumption is that the output shape was taken by mistake from here: https://github.com/lucidrains/egnn-pytorch/blob/main/README.md
The varying num_nodes is confusing. Can you explain whether this is just a mistake or whether num_nodes really changes, and if so, why?