materialsvirtuallab / megnet

Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals
BSD 3-Clause "New" or "Revised" License

Review request for disordered structure example #385

Open resnant opened 1 year ago

resnant commented 1 year ago

Hi, I have been exploring the application of MEGNet to structures with site mixing. I have written a simple example that uses the CrystalGraphDisordered class to handle these structures, and I would appreciate it if you could review it to confirm that my understanding is correct and that the processing contains no errors.

If this code is useful to others, I would like to submit it as a pull request, as an example of how to process disordered structures. I think MEGNet's ability to handle disordered structures is a powerful advantage, especially for researchers working with experimentally determined crystal structures.

Thank you in advance.

```python
from pymatgen.core import Structure, Lattice
from pymatgen.core.composition import Composition

from megnet.utils.models import load_model
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraphDisordered

# Load a pretrained MEGNet model
model_pretrained = load_model("Eform_MP_2019")

# Define a new MEGNet model with the same architecture as the pretrained model,
# but without the first (element embedding) layer.
# CrystalGraphDisordered computes each site's node features as the average of the
# element embeddings on that site, weighted by site fraction.
bond_converter = model_pretrained.graph_converter.bond_converter
cg_disorder = CrystalGraphDisordered(bond_converter=bond_converter, cutoff=5)

# Positional arguments: nfeat_edge=100 (must match the bond converter's output width),
# nfeat_global=2, nfeat_node=16 (size of the element embedding vectors)
model_new = MEGNetModel(100, 2, 16, graph_converter=cg_disorder)

# Skip weights[0], the pretrained element embedding table, and copy the rest
weights = model_pretrained.get_weights()
model_new.set_weights(weights[1:])

# An example of a disordered structure with site mixing
structure = Structure(Lattice.cubic(2.2), [Composition({"Fe": 0.5, "Ni": 0.5})], [[0, 0, 0]])

# Convert the crystal structure to a graph and make a prediction using the new MEGNet model
graph = model_new.graph_converter.convert(structure)
model_new.predict_graph(graph)  # -> array([0.4069269], dtype=float32)
```
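The weighted-average step described in the code comments can be illustrated without MEGNet itself. The sketch below uses only NumPy; `embedding_table` and `site_feature` are hypothetical names, and the table holds random 16-dimensional vectors standing in for the pretrained element embeddings (the real values are not reproduced here). It shows the arithmetic worth confirming in this review: a fully ordered site reduces to a plain embedding lookup, while a mixed site gets the fraction-weighted average. If the converter's element vectors are the same ones stored in the pretrained embedding layer (an assumption this sketch does not check), dropping that layer and feeding these vectors directly should leave predictions for ordered structures unchanged.

```python
import numpy as np

# Hypothetical stand-in for the pretrained element embeddings: random
# 16-dimensional vectors, used only to illustrate the arithmetic.
EMBEDDING_DIM = 16
rng = np.random.default_rng(seed=0)
embedding_table = {el: rng.normal(size=EMBEDDING_DIM) for el in ("Fe", "Ni", "Co")}


def site_feature(occupancy: dict) -> np.ndarray:
    """Weighted sum of element embeddings for one site (hypothetical helper).

    `occupancy` maps element symbols to site fractions, e.g. {"Fe": 0.5, "Ni": 0.5}.
    Fractions are used as given; no renormalization is applied.
    """
    feature = np.zeros(EMBEDDING_DIM)
    for element, fraction in occupancy.items():
        feature += fraction * embedding_table[element]
    return feature


# An ordered site reduces to a plain embedding lookup ...
print(np.allclose(site_feature({"Fe": 1.0}), embedding_table["Fe"]))  # True

# ... while a mixed site is the fraction-weighted average of its elements.
mixed = site_feature({"Fe": 0.5, "Ni": 0.5})
print(np.allclose(mixed, 0.5 * embedding_table["Fe"] + 0.5 * embedding_table["Ni"]))  # True
```

The sketch uses the raw site fractions, so a site whose fractions sum to less than 1 (for example, one with vacancies) simply gets a shorter vector; how such sites should be treated is left open here. A practical check in the same spirit is to convert an ordered structure with each model's own graph converter and compare the `predict_graph` outputs of `model_pretrained` and `model_new`.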

Related:

#317

https://github.com/materialsvirtuallab/megnet/blob/master/notebooks/model_reconstruct.ipynb

shyuep commented 1 year ago

Thanks. This is a great addition. However, I would suggest postponing development until the new reimplementation in the https://github.com/materialsvirtuallab/matgl repository is complete. The implementation itself is already done, but we still need to port the existing MEGNet models over. The new implementation will replace this TF implementation, and the current repo will be deprecated.

resnant commented 1 year ago

Thanks for your response. I agree to postpone further development of the current repo until the new implementation is complete. As a PyTorch user, I welcome the transition from TensorFlow to PyTorch and DGL. I'll check out the PyTorch implementations of MEGNet and M3GNet. Thank you again for the update, and please let me know if there is any other way I can help during the transition.