aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

Using custom trained models in `vibration_analysis.py` #612

Closed: rschireman closed this issue 2 years ago

rschireman commented 2 years ago

Hi all,

I recently trained an NNP from an ASE db file for an organic molecule (see #611). Is there a way to use models trained with examples/nnp_training_force.py in the vibration analysis script? I want to replace:

model = torchani.models.ANI1x(periodic_table_index=True).to(device).double()

with a model loaded from 'force-training-best.pt'.

Thanks again, Ray

shubbey commented 2 years ago

It depends on how you saved the model, but if you used the default training script, you just need to load the saved network weights into the same ANI model you originally trained, e.g.:

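import torch
import torchani

# Same per-element networks (and aev_computer) as in training, listed in species_order
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])  # plus any other elements
best_pt = torch.load('best.pt', map_location=device)
nn.load_state_dict(best_pt)
model = torchani.nn.Sequential(aev_computer, nn).to(device)
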
rschireman commented 2 years ago

Thanks for the quick response! Now I get an error when computing the AEV for my target molecule:

radial_aev.index_add_(0, index12[1], radial_terms_)
IndexError: index out of range in self
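A plausible explanation of this IndexError, assuming the installed TorchANI computes the radial AEV the way recent releases do: radial terms are accumulated into a buffer with num_molecules * num_atoms * num_species rows, and each neighbor pair is written to row atom_index * num_species + neighbor_species. The pure-Python sketch below models that arithmetic for a single molecule (the helper names are hypothetical). It shows that a species value at or above num_species, such as sulfur's atomic number 16 where the index 2 is expected, can push the row past the end of the buffer. The ANI1x(periodic_table_index=True) model in the original script expects atomic numbers, so a species tensor written for that model contains such values.

```python
def radial_buffer_rows(num_atoms: int, num_species: int) -> int:
    """Rows in the (simplified) radial accumulation buffer for one molecule."""
    return num_atoms * num_species


def radial_row(atom_index: int, neighbor_species: int, num_species: int) -> int:
    """Row a neighbor pair writes to: one block of num_species rows per atom."""
    return atom_index * num_species + neighbor_species


def out_of_range_rows(num_atoms, num_species, pairs):
    """Return (atom_index, neighbor_species, row) for pairs that overflow the buffer."""
    limit = radial_buffer_rows(num_atoms, num_species)
    bad = []
    for atom_index, neighbor_species in pairs:
        row = radial_row(atom_index, neighbor_species, num_species)
        if row >= limit:
            bad.append((atom_index, neighbor_species, row))
    return bad


# 24 atoms, 3 species (H, C, S). Sulfur given as atomic number 16 vs. index 2:
pairs = [(23, 16), (23, 2), (0, 16)]
print(out_of_range_rows(24, 3, pairs))  # [(23, 16, 85)]
```

The pair (0, 16) does not overflow: it silently lands in rows that belong to another atom, so wrong species values do not always fail loudly.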

My network is as follows:

import ase.io
import torch
import torchani

device = torch.device('cpu')
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['S', 'C', 'H']

num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
###############################################################################
# Network definitions: these must match the training script exactly so the saved weights load
aev_dim = aev_computer.aev_length

H_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

S_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, S_network])
best_pt = torch.load('force-training-best.pt')
print(best_pt)
nn.load_state_dict(best_pt)
model = torchani.nn.Sequential(aev_computer, nn).to(device)
print(model)

molecule = ase.io.read("BTBT.pdb")
print(molecule.get_chemical_symbols())

BTBT.pdb is a 24-atom organic molecule, and my dataset contains 10,000 off-equilibrium structures of it with the corresponding energies and forces. I suspect species_order is wrong, but I can't tell where the error is coming from.
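For reference, the aev_dim in the script above can be worked out from the sizes of the parameter tensors. The sketch below assumes TorchANI's layout, with one radial block per species and one angular block per unordered species pair; the helper name aev_length is hypothetical.

```python
def aev_length(num_species, n_eta_r, n_shf_r, n_eta_a, n_zeta, n_shf_a, n_shf_z):
    """AEV length for TorchANI-style radial + angular descriptors."""
    radial_sublength = n_eta_r * n_shf_r                      # per species
    angular_sublength = n_eta_a * n_zeta * n_shf_a * n_shf_z  # per species pair
    num_species_pairs = num_species * (num_species + 1) // 2  # unordered pairs
    return num_species * radial_sublength + num_species_pairs * angular_sublength


# Sizes of EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ in the script above, with 3 species:
print(aev_length(3, 1, 16, 1, 1, 4, 8))  # 240
```

With four species (H, C, N, O, as in ANI-1x) and the same parameter sizes, the formula gives 384. Because num_species fixes this layout, the species tensor passed to the model should contain indices from 0 to num_species - 1 (TorchANI uses -1 for padding atoms).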

shubbey commented 2 years ago

You probably want species_order = ['H','C','S'] so that it matches the order of the networks in your ANIModel.
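One way to keep the two orders from drifting apart is to key the networks by element and build the list from species_order itself. The sketch below uses plain strings as stand-ins for the torch.nn.Sequential networks so it runs without PyTorch; the helper name is hypothetical.

```python
def networks_in_species_order(networks_by_element, species_order):
    """Return the per-element networks as a list ordered like species_order."""
    missing = [s for s in species_order if s not in networks_by_element]
    if missing:
        raise ValueError(f"no network defined for: {missing}")
    return [networks_by_element[s] for s in species_order]


# Stand-ins for H_network, C_network, S_network.
networks_by_element = {"H": "H_network", "C": "C_network", "S": "S_network"}
species_order = ["H", "C", "S"]
print(networks_in_species_order(networks_by_element, species_order))
# ['H_network', 'C_network', 'S_network']
```

In the real script, the returned list would be passed to torchani.ANIModel, so the network order and the species indices both come from the same species_order.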

rschireman commented 2 years ago

I retrained with species_order=['H','C','S'] and the same error still occurs. The traceback points to this line:

energies = model((species, coordinates)).energies

rschireman commented 2 years ago

The order of atoms in the PDB file is:

['S', 'C', 'C', 'H', 'C', 'H', 'C', 'H', 'C', 'H', 'C', 'C', 'S', 'C', 'C', 'H', 'C', 'H', 'C', 'H', 'C', 'C', 'C', 'H']

rschireman commented 2 years ago

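The solution was to convert the species to TorchANI's indices, i.e. positions in species_order (H=0, C=1, S=2). A model assembled from AEVComputer and ANIModel has no periodic_table_index option, so it expects these indices rather than atomic numbers: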

species = torch.tensor([[2, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0]],device=device,dtype=torch.long)
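Rather than typing the index list by hand, it can be derived from the symbols ASE reads, using the same species_order as in training. The pure-Python sketch below (function name hypothetical) reproduces the list above. TorchANI also provides torchani.utils.ChemicalSymbolsToInts for this mapping; either way, the result still needs to be wrapped as a torch.long tensor with a leading batch dimension, as in the line above.

```python
def symbols_to_species_indices(symbols, species_order):
    """Map chemical symbols to their positions in species_order (H=0, C=1, S=2 here)."""
    index_of = {symbol: i for i, symbol in enumerate(species_order)}
    unknown = sorted(set(symbols) - set(index_of))
    if unknown:
        raise ValueError(f"elements not in species_order: {unknown}")
    return [index_of[s] for s in symbols]


species_order = ["H", "C", "S"]
# Symbols as printed by molecule.get_chemical_symbols() for BTBT.pdb:
symbols = ["S", "C", "C", "H", "C", "H", "C", "H", "C", "H", "C", "C",
           "S", "C", "C", "H", "C", "H", "C", "H", "C", "C", "C", "H"]
print(symbols_to_species_indices(symbols, species_order))
```

Deriving the indices this way works for any geometry read with ase.io.read, and the helper raises on an element the model was not trained on instead of producing an out-of-range index.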