FabianFuchsML / se3-transformer-public

Code for the SE(3)-Transformers paper: https://arxiv.org/abs/2006.10503

add a toy example #1

Closed Chen-Cai-OSU closed 3 years ago

Chen-Cai-OSU commented 3 years ago

Hello,

Thank you for the nice paper. I would like to try the se3-transformer on my dataset. However, I am a PyTorch Geometric user and not familiar with DGL. Would it be possible to provide a very simple example?

Right now, in the basic usage section of the README, the format of G is not specified. I guess that since the se3-transformer deals with different types of features, they should be put under different keys. Again, a simple example would make it easy for new users to try things out.

I printed out one QM9 graph in DGL format:

(DGLGraph(num_nodes=23, num_edges=506,
         ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'f': Scheme(shape=(6, 1), dtype=torch.float32)}
         edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(5,), dtype=torch.float32)}), tensor([[-0.2169]]))

Should I organize my features in a similar format? Thank you!

FabianFuchsML commented 3 years ago

Hi Chen,

Happy to hear that you liked our paper!

That's good feedback; it is on my todo list to make this part more intuitive. Let me try below, using the qm9 example.

The most important bit is that you save the relative positions as the edge feature 'd': those are used to evaluate the spherical harmonics (in a function called get_basis_and_r(G, max_degree)), and this is what makes the layers equivariant.

'f' in the above is a (scalar) node feature; this is handled in the model class (qm9/models.py). You could replace it with a higher-degree feature which rotates (like a velocity); you would just have to change the input fibre accordingly in the model definition. You can also just feed in ones, but you do have to feed in something.

'w' is optional and is a scalar edge feature. If you don't have any edge features, just don't specify it. However, you will then need to make a small adjustment in equivariant_attention/modules.py: there, we concatenate 'w' to r (search for 'w'). Right now, in two places, it says:

w = G.edata['w']
feat = torch.cat([w, r], -1)

replace this with something like:

if 'w' in G.edata:
    w = G.edata['w']
    feat = torch.cat([w, r], -1)
else:
    feat = r

Chen-Cai-OSU commented 3 years ago

Thanks a lot! I actually spent two hours this morning making up a small example. Now, after reading your reply, it's working!

I still have a few quick questions:

Chen-Cai-OSU commented 3 years ago

Another question is about the construction of the graph. Although the paper is about point clouds, you still represent each molecule as a graph. I looked at a few examples, and it seems that you are using a complete graph?

The dataset I want to apply the se3-transformer to is actually a graph, where the nodes have position vectors and some other features, and I want to achieve equivariance w.r.t. rotation (the output vector for each graph should rotate if I rotate the input).

My current understanding is that there is no extra complication caused by switching from point clouds (complete graph) to a sparse graph: just building the graph and providing the features should be enough. Is that correct? I am worried that there are some aspects I have overlooked.

FabianFuchsML commented 3 years ago

What is the key 'x' in the node features? Is it the coordinates of each atom?

Yes. It might be useful to store this depending on what operations you do later on; however, in the current implementation of the qm9 experiment, I believe we do not use it at all (we just use 'd').

For 'd', if I have nodes a and b, is it equal to a.pos - b.pos (a.pos meaning the coordinates of point a)? It seems that b.pos - a.pos would also be valid? I think it's related to how DGL handles directed edges.

We set up a directed graph here, so there is always a source and a destination; we define the relative position as x_dest - x_source
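As a plain-Python illustration of that convention (the helper name here is made up):

```python
# Illustrative only: compute the relative-position edge feature 'd'
# for a directed edge list, using the convention d = x_dest - x_source.
def relative_positions(coords, edges):
    """coords: list of [x, y, z] per node; edges: list of (src, dst) pairs."""
    return [[coords[d][k] - coords[s][k] for k in range(3)] for s, d in edges]

coords = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
print(relative_positions(coords, [(0, 1), (1, 0)]))
# [[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]
```

Note that each undirected neighbour pair contributes two directed edges with opposite 'd' vectors.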

Most importantly, for 'd' and 'w': if I want to incorporate, say, a scalar feature with multiplicity 2 and a vector feature with multiplicity 3, how should I organize the data? An example would be great.

Are you saying you need vector edge features? That is a little bit tricky and will require some coding; on the plus side, it does force you to understand what's really happening in the different parts of the code and the network. We currently feed the edge features as an input to the radial basis function. This works because the edge features in qm9 are scalars and therefore invariant w.r.t. rotation. If you want rotating edge features (e.g. vectors), you need to handle them as features of the network instead. Off the top of my head, you will probably need to go into qm9/models.py and then into the forward function of the model you are using. There, you have an h_enc which updates the node features in each layer. At this point, you will probably have to concatenate the edge features (and update your fibre structures accordingly).
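A rough sketch of that last step, with made-up names (nothing below is from the repo): average the incoming degree-1 edge features per destination node and concatenate them onto the node's degree-1 channels. Averaging is linear in the vectors, so it preserves equivariance.

```python
import torch

def concat_edge_to_node(node_feat_1, edge_feat_1, dst, num_nodes):
    """Hypothetical helper.
    node_feat_1: (N, c_n, 3) degree-1 node features;
    edge_feat_1: (E, c_e, 3) degree-1 edge features;
    dst: (E,) destination node of each directed edge."""
    agg = torch.zeros(num_nodes, edge_feat_1.shape[1], 3)
    count = torch.zeros(num_nodes, 1, 1)
    agg.index_add_(0, dst, edge_feat_1)                   # sum incoming edge vectors
    count.index_add_(0, dst, torch.ones(len(dst), 1, 1))  # incoming-edge counts
    agg = agg / count.clamp(min=1)                        # mean over incoming edges
    return torch.cat([node_feat_1, agg], dim=1)           # (N, c_n + c_e, 3)
```

The fibre structure of that layer's input would then need c_n + c_e degree-1 channels instead of c_n.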

FabianFuchsML commented 3 years ago

My current understanding is that there is no extra complication caused by switching from point clouds (complete graph) to a sparse graph: just building the graph and providing the features should be enough. Is that correct? I am worried that there are some aspects I have overlooked.

I assume by 'complete' you mean 'fully connected'? You can use our code for both fully connected and sparsely connected graphs. The qm9 example is fully connected.
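Either way, a useful end-to-end sanity check is to rotate the inputs and verify that the outputs rotate accordingly. A framework-free sketch (the toy model here, the per-graph mean coordinate, is just a stand-in for the real network):

```python
import math

# Hypothetical sanity check: a rotation-equivariant model satisfies
# model(R @ x) == R @ model(x) for any rotation R.
def rot_z(theta):
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matvec(R, v):
    return [sum(R[i][k] * v[k] for k in range(3)) for i in range(3)]

def model(coords):  # toy equivariant model: mean of the coordinates
    n = len(coords)
    return [sum(c[k] for c in coords) / n for k in range(3)]

coords = [[1.0, 0.0, 0.0], [0.0, 2.0, 1.0]]
R = rot_z(math.pi / 3)
out_rot = model([matvec(R, c) for c in coords])  # rotate inputs, then run
rot_out = matvec(R, model(coords))               # run, then rotate output
assert all(abs(a - b) < 1e-9 for a, b in zip(out_rot, rot_out))
```

The same check applies to the real network: rebuild 'd' (and any rotating features) from the rotated coordinates, run a forward pass, and compare with rotating the original output.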

Chen-Cai-OSU commented 3 years ago

Most importantly, for 'd' and 'w': if I want to incorporate, say, a scalar feature with multiplicity 2 and a vector feature with multiplicity 3, how should I organize the data? An example would be great.

I think right now I can start with something simple (no vector features on nodes or edges except 'd') and later on try to include those. Thank you very much! I am closing the issue.