usnistgov / alignn

Atomistic Line Graph Neural Network https://scholar.google.com/citations?user=9Q-tNnwAAAAJ&hl=en
https://jarvis.nist.gov/jalignn/
Other
220 stars 80 forks

Classification model prediction #85

Open knc6 opened 1 year ago

knc6 commented 1 year ago

Add an example showing how to make predictions with a trained classification model, something like the following:

from alignn.models.alignn import ALIGNN
import pprint
import torch
from alignn.config import TrainingConfig
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph
from jarvis.db.jsonutils import loadjson

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

filename = "checkpoint_100.pt"  # checkpoint saved during training
# Graph-construction settings; these should match the values used in training.
cutoff = 8
max_neighbors = 12

# Rebuild the model from the training configuration and load its weights.
config = loadjson("config.json")
pprint.pprint(config)  # pprint already prints; wrapping it in print() adds "None"
config = TrainingConfig(**config)
model = ALIGNN(config.model)
model.load_state_dict(torch.load(filename, map_location=device)["model"])
model.to(device)
model.eval()

# Convert the structure into an atom graph and its line graph.
atoms = Atoms.from_poscar("POSCAR")
g, lg = Graph.atom_dgl_multigraph(
    atoms,
    cutoff=float(cutoff),
    max_neighbors=max_neighbors,
)

# Inference only: no gradients needed. The predicted class is the
# index of the highest class score.
with torch.no_grad():
    output = model([g.to(device), lg.to(device)])
out_data = torch.argmax(output).item()
print("Predicted class:", out_data)
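The argmax above only reports the winning class. If you also want per-class probabilities, here is a minimal sketch that assumes the classification head returns log-probabilities (i.e. a LogSoftmax over classes) and that the output for a single structure has been converted to a plain list, e.g. `output.cpu().tolist()`. The helper name `class_probabilities` is hypothetical and not part of alignn.

```python
import math
from typing import List, Tuple


def class_probabilities(log_probs: List[float]) -> Tuple[int, List[float]]:
    """Turn per-class log-probabilities into (predicted class, probabilities).

    Hypothetical helper. It applies a softmax to the scores; for log-softmax
    inputs this equals exp(score), but the renormalization also guards
    against rounding error.
    """
    if not log_probs:
        raise ValueError("log_probs must contain at least one class score")
    # Subtract the largest score before exponentiating for numerical stability.
    peak = max(log_probs)
    weights = [math.exp(s - peak) for s in log_probs]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Index of the largest probability (the first one wins on ties).
    predicted = max(range(len(probs)), key=probs.__getitem__)
    return predicted, probs
```

For example, `class_probabilities([math.log(0.2), math.log(0.8)])` gives class `1` with probabilities close to `[0.2, 0.8]`.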
Halfnote24 commented 6 months ago

Can I make predictions on a new dataset of CIF files using a model trained on an old dataset?
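Mechanically, predicting on a folder of new structures means repeating the per-structure steps from the first comment for each file. The sketch below handles only the looping. `predict_directory` and `predict_fn` are hypothetical names. `predict_fn` stands for a callable you write that reads one file (for example with `Atoms.from_cif`, assuming your jarvis-tools version provides it, in place of `Atoms.from_poscar("POSCAR")`), builds the graphs with the same `cutoff` and `max_neighbors` used in training, and returns the model's prediction.

```python
from pathlib import Path
from typing import Callable, Dict, TypeVar

T = TypeVar("T")


def predict_directory(
    directory: str,
    predict_fn: Callable[[Path], T],
    pattern: str = "*.cif",
) -> Dict[str, T]:
    """Apply a (hypothetical) per-file predict_fn to every matching file.

    Returns a dict mapping each file name to its prediction, in sorted
    file-name order so results are reproducible across runs.
    """
    results: Dict[str, T] = {}
    for path in sorted(Path(directory).glob(pattern)):
        if path.is_file():  # skip directories that happen to match
            results[path.name] = predict_fn(path)
    return results
```

This only covers the mechanics. How meaningful the predictions are still depends on how closely the new structures resemble the data the model was trained on.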