CompRhys / aviary

The Wren sits on its Roost in the Aviary.
MIT License

Wren: Why does averaging of augmented Wyckoff positions happen inside the NN, after message passing? #54

Closed sgbaird closed 2 years ago

sgbaird commented 2 years ago

https://www.science.org/doi/epdf/10.1126/sciadv.abn4117

The categorization of Wyckoff positions depends on a choice of origin (50). Hence, there is not a unique mapping between the crystal structure and the Wyckoff representation. To ensure that the model is invariant to the choice of origin, we perform on-the-fly augmentation of Wyckoff positions with respect to this choice of origin (see Fig. 6). The augmented representations are averaged at the end of the message passing stage to provide a single representation of equivalent Wyckoff representations to the output network. By pooling at this point, we ensure that the model is invariant and that its training is not biased toward materials for which many equivalent Wyckoff representations exist.

Probably a noob question here. I think I understand that it needs to happen at some point, but why does it need to happen after message passing? Why not implement this at the very beginning (i.e. in the input data representation)? Not so much doubtful of the choice as I am interested in the mechanics behind this choice. A topic that's come up in another context for me.

CompRhys commented 2 years ago

Not a noob question. I am not an expert on group theory/crystallography myself, so I spent a lot of time deliberating about these decisions. My reasoning was that averaging before the message passing would be the same as collapsing all the equivalent Wyckoff positions that can be relabelled to only be encodings of the site symmetry. Whereas averaging after allows the model to maintain the fact that sites with the same site symmetry are distinct. To give an explicit example: https://www.cryst.ehu.es/cgi-bin/cryst/programs/nph-normsets?from=wycksets&gnum=68 shows that in space group 68 we have 3 sets with the same site symmetry that are nonetheless distinct, and so should have distinct embeddings if we want to make use of all the information in the representation.
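To make that argument concrete, here is a minimal sketch with made-up pieces: two hypothetical Wyckoff letters `a` and `b` that swap under a change of origin, hand-picked 2-d embeddings, and a toy one-step "message passing" encoder. None of this is Wren's actual code. It only illustrates that averaging the inputs makes two different structures indistinguishable, while averaging the encoded outputs keeps them apart and stays invariant to the relabelling.

```python
import math

# Hypothetical embeddings for two Wyckoff letters (values chosen by hand).
EMBED = {"a": [1.0, 0.2], "b": [-0.5, 0.7]}
# Assumed relabelling induced by the alternative choice of origin.
RELABEL = {"a": "b", "b": "a"}


def mean(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def encode(site_vectors):
    """Toy encoder: one message-passing step, then sum pooling over sites."""
    dim = len(site_vectors[0])
    pooled = [0.0] * dim
    for i, h_i in enumerate(site_vectors):
        # Message to site i: sum of the vectors on all other sites.
        msg = [
            sum(h[k] for j, h in enumerate(site_vectors) if j != i)
            for k in range(dim)
        ]
        for k in range(dim):
            pooled[k] += math.tanh(h_i[k] * msg[k])
    return pooled


def augment(letters):
    """All equivalent labellings of a structure (here: original + relabelled)."""
    return [list(letters), [RELABEL[x] for x in letters]]


def average_before(letters):
    """Average the input embeddings per site, then encode once."""
    reps = augment(letters)
    sites = [mean([EMBED[rep[i]] for rep in reps]) for i in range(len(letters))]
    return encode(sites)


def average_after(letters):
    """Encode every equivalent labelling, then average the outputs."""
    return mean([encode([EMBED[x] for x in rep]) for rep in augment(letters)])


# Structures (a, a) and (a, b) are different, but averaging first hides that.
print(average_before(["a", "a"]), average_before(["a", "b"]))
# Averaging afterwards keeps them distinct ...
print(average_after(["a", "a"]), average_after(["a", "b"]))
# ... while equivalent labellings (a, a) and (b, b) still map to the same output.
print(average_after(["a", "a"]), average_after(["b", "b"]))
```

In this toy setting every site ends up with the same averaged input vector under `average_before`, which is the "collapsing" described above.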

janosh commented 2 years ago

My reasoning was that averaging before the message passing would be the same as collapsing all the equivalent Wyckoff positions that can be relabelled to only be encodings of the site symmetry. Whereas averaging after allows the model to maintain the fact that sites with the same site symmetry are distinct.

@CompRhys You're referring to the distinction made here, right?

CompRhys commented 2 years ago

The other thing that, on rereading, isn't so clear in the text (but is in the figure) is that we first pool the token representations to get a materials embedding for each equivalent representation, and then average those pooled materials embeddings across the equivalent representations.
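A rough sketch of that two-stage order, using plain mean pooling as a stand-in for whatever pooling the model actually uses (the function names and the pooling choice are assumptions for illustration only):

```python
def mean(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def material_embedding(equivalent_reps):
    """equivalent_reps: one list of token vectors per equivalent representation,
    taken after message passing."""
    # Stage 1: pool the tokens of each representation into one vector.
    pooled = [mean(tokens) for tokens in equivalent_reps]
    # Stage 2: average the pooled vectors over the equivalent representations.
    return mean(pooled)


reps = [
    [[1.0, 0.0], [0.0, 1.0]],  # token vectors for representation 1
    [[0.0, 2.0], [2.0, 0.0]],  # token vectors for representation 2
]
embedding = material_embedding(reps)
```

Because the second stage is an average, a material yields a single embedding no matter how many equivalent representations it has, which is the "not biased toward materials for which many equivalent Wyckoff representations exist" point from the paper.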

janosh commented 2 years ago

A topic that's come up in another context for me.

@sgbaird I'd be curious what the other context is? And what the results were in case you tried directly averaging the input embedding rather than after message-passing/transformer-encoding. I wouldn't expect performance to be that much worse, maybe even the same.

sgbaird commented 2 years ago

@CompRhys thanks! I appreciate the discussion and clarification.

@janosh nothing that I've implemented directly. The two other places where this has gotten me thinking:

CDVAE manuscript, Section 4

$\mathrm{PGNN}_{\mathrm{DEC}}$ is parameterized by a SE(3) equivariant PGNN that inputs a multi-graph representation (section 3.1) of the noisy material structure and the latent representation. The node embedding for node $i$ is obtained by the concatenation of the element embedding of $\tilde{a}_i$ and the latent representation $z$, followed by a MLP, $h_i^0 = \mathrm{MLP}(e_a(\tilde{a}_i) \,\|\, z)$, where $\|$ denotes concatenation of two vectors and $e_a$ is a learned embedding for elements. After $K$ message-passing layers, $\mathrm{PGNN}_{\mathrm{DEC}}$ outputs a vector per node that is equivariant to the rotation of $\tilde{M}$ (emphasis added)
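As I read it, the initial node embedding in that passage amounts to "look up the element embedding, append the latent vector, run an MLP". A minimal sketch follows, with random weights standing in for learned ones, no bias terms, and made-up dimensions. All names here are hypothetical and this is not the CDVAE implementation.

```python
import random

random.seed(0)
EMBED_DIM, LATENT_DIM, HIDDEN_DIM = 4, 3, 5  # illustrative sizes only


def random_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


# Stand-in for the learned element embedding, keyed by atomic number.
ELEMENT_EMBED = {
    number: [random.uniform(-1, 1) for _ in range(EMBED_DIM)]
    for number in (8, 22, 38)
}
W1 = random_matrix(HIDDEN_DIM, EMBED_DIM + LATENT_DIM)
W2 = random_matrix(HIDDEN_DIM, HIDDEN_DIM)


def matvec(weights, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def mlp(x):
    """Two linear layers with a ReLU in between (biases omitted for brevity)."""
    hidden = [max(0.0, v) for v in matvec(W1, x)]
    return matvec(W2, hidden)


def initial_node_embeddings(atomic_numbers, latent):
    """Concatenate each node's element embedding with the shared latent, then MLP."""
    return [mlp(ELEMENT_EMBED[a] + list(latent)) for a in atomic_numbers]


# Five nodes (Sr, Ti, O, O, O) sharing one latent vector.
h0 = initial_node_embeddings([38, 22, 8, 8, 8], [0.1, -0.2, 0.3])
```

Note that at this stage nodes of the same element get identical embeddings; any dependence on geometry only enters through the later message-passing layers.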

Vienna Summer School

During the lectures, there was discussion about equivariance vs. invariance and how the one you want depends on where you collapse things. For example, I think equivariance is implemented by collapsing the symmetric representations later in the stack (i.e. the algorithm is allowed to distinguish between symmetric representations up until close to the end, e.g. a final pooling layer). In contrast, invariance maybe is implemented towards the beginning so that the model doesn't distinguish between symmetric representations and treats them as identical. Whether invariance vs. equivariance is desired depends on the application.

I wasn't clear on these topics, and I might be overgeneralizing, misremembering, or misinterpreting. I think it was James Spencer and/or David Pfau from DeepMind who talked about this, but maybe @halvarsu can correct me if I'm wrong.

sgbaird commented 2 years ago

Follow-up snippet from CDVAE manuscript, Appendix B.3:

GNN ARCHITECTURE We use DimeNet++ adapted for periodicity (Klicpera et al., 2020a;b) as the encoder, which is SE(3) invariant to the input structure. The decoder needs to output an vector per node that is SE(3) equivariant to the input structure. We use GemNet-dQ (Klicpera et al., 2021) as the decoder. We used implementations from the Open Catalysis Project (OCP) (Chanussot et al., 2021), but we reduced the size of hidden dimensions to 128 for faster training. The encoder has 2.2 million parameters and the decoder has 2.3 million parameters.

Interestingly, the encoder is SE(3) invariant while the decoder is SE(3) equivariant. I'm curious why.

janosh commented 2 years ago

Interestingly, the encoder is SE(3) invariant while the decoder is SE(3) equivariant. I'm curious why.

Haven't read CDVAE yet (though I'm planning to). Are you asking why the encoder is not also equivariant? Or why the decoder can't be invariant? If the 2nd, then the answer seems to be

The decoder needs to output an vector per node that is SE(3) equivariant to the input structure.

janosh commented 2 years ago

Though looking at the paper now, another paragraph seems to contradict the one you quoted:

To capture the necessary invariances and encode the interactions crossing periodic boundaries, we use SE(3) equivariant graph neural networks adapted with periodicity (PGNNs) for both the encoder and decoder of our VAE.

I'm guessing that's an oversight since they repeat elsewhere that the encoder is invariant.

sgbaird commented 2 years ago

@janosh good catch. @txie-93 or @kyonofx maybe you could confirm?

txie-93 commented 2 years ago

Thanks for pointing this out! Yes, this is indeed an oversight. The encoder is invariant because it encodes the crystal into a latent vector, and that latent vector is invariant to the symmetry operations.
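For anyone following along, the invariant/equivariant distinction can be checked numerically on a toy example. The sketch below is a 2-D, non-periodic analogue with hand-written functions (not DimeNet++ or GemNet): the "encoder" returns pairwise distances, which do not change under rotation, and the "decoder" returns one vector per node, which rotates together with the input.

```python
import math


def rotate(points, angle):
    """Rotate 2-D points (or vectors) about the origin by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y, s * x + c * y) for x, y in points]


def invariant_encoder(points):
    """Toy latent: sorted pairwise distances, unchanged by rotating the input."""
    return sorted(
        math.dist(p, q) for i, p in enumerate(points) for q in points[i + 1:]
    )


def equivariant_decoder(points):
    """Toy per-node output: distance-weighted sum of displacements to other nodes."""
    out = []
    for i, (xi, yi) in enumerate(points):
        vx = vy = 0.0
        for j, (xj, yj) in enumerate(points):
            if i == j:
                continue
            d = math.hypot(xj - xi, yj - yi)
            weight = 1.0 / (1.0 + d)  # depends on distance only
            vx += weight * (xj - xi)
            vy += weight * (yj - yi)
        out.append((vx, vy))
    return out


points = [(0.0, 0.0), (1.0, 0.0), (0.3, 1.2)]
angle = 0.7
rotated = rotate(points, angle)

# Invariance: encoding the rotated structure gives the same latent.
latent, latent_rot = invariant_encoder(points), invariant_encoder(rotated)
# Equivariance: decoding the rotated structure equals rotating the decoded output.
vectors_rot = equivariant_decoder(rotated)
rotated_vectors = rotate(equivariant_decoder(points), angle)
```

Up to floating-point error, `latent` matches `latent_rot` and `vectors_rot` matches `rotated_vectors`, which is consistent with a latent vector being invariant while a per-node vector output has to be equivariant.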

sgbaird commented 2 years ago

Thanks @txie-93!