torchmd / torchmd-net

Training neural network potentials
MIT License

Loss does not converge #212

Closed: xiehuanyi closed this issue 10 months ago

xiehuanyi commented 1 year ago

I tried the provided equiformer, but it does not seem to converge when I use a constant learning rate. I don't know if there is an error in my code.

from torchmdnet.models.model import load_model
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
from tqdm import tqdm
import time

# truncate the log file
with open('/code/ckpts/exp/log.txt', 'w') as f:
    f.write('')

class Molset(Dataset):
    def __init__(self, data, train_ratio=0.95, mode='train'):
        super().__init__()
        self.mode = mode
        self.train_ratio = train_ratio
        self.data = data

        # Use a fixed seed so the 'train' and 'eval' instances see the same
        # permutation; otherwise the two splits would overlap.
        length = len(self.data)
        indices = np.random.default_rng(0).permutation(length)
        train_size = int(length * self.train_ratio)

        if self.mode == 'train':
            self.data = self.data[indices[:train_size]]
        else:
            self.data = self.data[indices[train_size:]]

    def __getitem__(self, idx):
        data = self.data[idx]
        z, pos, energy, forces = data['elements'], data['coordinates'], data['energy'], data['force']
        z = np.array(z)
        energy = np.array([energy])
        forces = forces.reshape([-1, 3])
        return z, pos, energy, forces

    def __len__(self):
        return len(self.data)

def collate_fn(batch):
    # flatten the molecules into one long atom list and build the
    # per-atom batch index the model expects
    zs = [i[0] for i in batch]
    poses = [i[1] for i in batch]
    energy = [i[2] for i in batch]
    forces = [i[3] for i in batch]
    batch_ind = []
    for idx, sample in enumerate(batch):
        batch_ind += [idx] * len(sample[0])

    zs = np.concatenate(zs, axis=0)
    poses = np.concatenate(poses, axis=0)
    energy = np.concatenate(energy, axis=0)
    forces = np.concatenate(forces, axis=0)
    batch_ind = np.array(batch_ind)
    return torch.tensor(zs, dtype=torch.long), \
            torch.tensor(poses, dtype=torch.float32), \
            torch.tensor(energy, dtype=torch.float32), \
            torch.tensor(forces, dtype=torch.float32), \
            torch.tensor(batch_ind, dtype=torch.long)

# torch.no_grad() cannot be used here: with derivative=True the model
# computes forces as the gradient of the energy, which needs autograd.
def evaluate(model, eval_loader):
    model.eval()
    losses = []
    for batch in eval_loader:
        z, pos, energy, forces, batch_ind = batch
        z, pos, energy, forces, batch_ind = z.to(device), pos.to(device), energy.to(device), forces.to(device), batch_ind.to(device)
        pred_ene, pred_for = model(z, pos, batch_ind)
        loss_ene = loss_fn(pred_ene.squeeze(-1), energy)
        loss_for = loss_fn(pred_for, forces)
        loss = loss_ene + 10 * loss_for  # note: training uses a factor of 100
        losses.append(loss.detach())
    model.train()
    return sum(losses) / len(losses)

data = np.load('d.npy', allow_pickle=True)
train_loader = DataLoader(Molset(data), batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=8)
eval_loader = DataLoader(Molset(data, mode='eval'), batch_size=32, collate_fn=collate_fn, num_workers=8)

model = load_model("ckpts/et-md17/epoch=2139-val_loss=0.2543-test_loss=0.2317.ckpt",
                   derivative=True, max_num_neighbors=64)

device = 'cuda'
opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fn = nn.L1Loss()
epochs = 3
steps = 1
log_steps = 5000

model = model.to(device)
for epoch in range(epochs):
    train_tqdm = tqdm(train_loader)
    t0 = time.time()
    for batch in train_tqdm:
        opt.zero_grad()
        z, pos, energy, forces, batch_ind = batch
        z, pos, energy, forces, batch_ind = z.to(device), pos.to(device), energy.to(device), forces.to(device), batch_ind.to(device)
        pred_ene, pred_for = model(z, pos, batch_ind)
        # squeeze the (N, 1) energy prediction to match the (N,) target
        # and avoid silent broadcasting in the loss
        loss_ene = loss_fn(pred_ene.squeeze(-1), energy)
        loss_for = loss_fn(pred_for, forces)
        loss = loss_ene + 100 * loss_for  # note: evaluate() uses a factor of 10 instead
        loss.backward()
        opt.step()
        steps += 1
        if steps % log_steps == 0:
            l_eval = evaluate(model, eval_loader)
            train_tqdm.set_postfix({'loss_ene': loss_ene.item(),
                                    'loss_for': loss_for.item(),
                                    'train loss': loss.item(),
                                    'eval loss': l_eval.item()})
            with open('/code/ckpts/exp/log.txt', 'a') as f:
                f.write(f'epoch: {epoch} steps: {steps} loss_ene: {loss_ene.item()}, loss_for: {loss_for.item()}, train loss: {loss.item()}, eval loss: {l_eval.item()}\n')

torch.save(model.state_dict(), '/code/ckpts/exp/model_parameters.pth')

RaulPPelaez commented 1 year ago

Assuming there are no bugs in your code, you seem to be running for just 3 epochs. What do you mean, then, when you say the loss does not converge? Could you provide some graphs/numbers? As far as I can tell, your example also does not provide information about the dataset you are using, so it's hard to tell!

Is there any reason you are not using the provided torchmd-train utility? I find it quite featureful! Most of your code seems related to the Dataset, in which case the provided Custom class might be what you are looking for: https://github.com/torchmd/torchmd-net/blob/dca66796d00680a79a7b7c85d6704d30d15dc84c/torchmdnet/datasets/custom.py#L7-L31 If there is some functionality you are missing, let us know so we can consider adding it.
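
For illustration, converting your d.npy records into the per-molecule .npy files that Custom globs for could look something like the sketch below. The file patterns and array shapes here (coords*: [n_confs, n_atoms, 3], embed*: [n_atoms], energy*: [n_confs, 1], forces*: [n_confs, n_atoms, 3]) are assumptions taken from the dataset docs, so please double-check them against the linked source:

import os
import numpy as np

# Sketch: write one set of files per molecule, since the molecules
# have different atom counts. Names and shapes are assumptions, not
# a tested recipe.
os.makedirs('custom_data', exist_ok=True)
data = np.load('d.npy', allow_pickle=True)
for i, rec in enumerate(data):
    pos = np.asarray(rec['coordinates'], dtype=np.float32)[None]      # (1, n_atoms, 3)
    z = np.asarray(rec['elements'], dtype=np.int64)                   # (n_atoms,)
    e = np.asarray([[rec['energy']]], dtype=np.float64)               # (1, 1)
    f = np.asarray(rec['force'], dtype=np.float32).reshape(1, -1, 3)  # (1, n_atoms, 3)
    np.save(f'custom_data/coords_{i:06d}.npy', pos)
    np.save(f'custom_data/embed_{i:06d}.npy', z)
    np.save(f'custom_data/energy_{i:06d}.npy', e)
    np.save(f'custom_data/forces_{i:06d}.npy', f)

You would then point the yaml config at custom_data/coords_*.npy and friends (see the Custom dataset example for the exact option names).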

xiehuanyi commented 1 year ago
  1. The evaluation loss looks like this:
    #   eval_loss       epoch   step
    1   120946425856.0  0   499
    2   126612365312.0  0   999
    3   132942659584.0  0   1499
    4   106031333376.0  0   1999
    5   91581849600.0   0   2499
    6   139767087104.0  0   2999
    7   311060463616.0  0   3499
    8   96562634752.0   0   3999
    9   149120892928.0  0   4499
    10  122255409152.0  0   4999
    11  113533616128.0  0   5499
    12  216064688128.0  0   5999
    13  101125210112.0  0   6499
    14  721705893888.0  0   6999
    15  147561873408.0  0   7499
    16  135868628992.0  0   7999
    17  199190691840.0  0   8499
    18  89850052608.0   0   8999
    19  69984813056.0   0   9499
    20  155524808704.0  0   9999
    21  153351618560.0  1   10499
    22  129449132032.0  1   10999
    23  231814447104.0  1   11499
    24  64472707072.0   1   11999
    25  122545004544.0  1   12499
    26  80596549632.0   1   12999
    27  77701513216.0   1   13499
    28  120324743168.0  1   13999
    29  213828239360.0  1   14499
    30  131378298880.0  1   14999
    31  97231765504.0   1   15499
    32  141100646400.0  1   15999
    33  142118600704.0  1   16499
    34  79063613440.0   1   16999
    35  277870936064.0  1   17499
    36  149919236096.0  1   17999
    37  118923984896.0  1   18499
    38  66087944192.0   1   18999
    39  153931972608.0  1   19499
    40  117702656000.0  1   19999
    41  144633200640.0  1   20499
    42  116774649856.0  2   20999
    43  145603592192.0  2   21499
    44  144544120832.0  2   21999
    45  114598690816.0  2   22499
    46  95283609600.0   2   22999
    47  95340781568.0   2   23499
    48  86069018624.0   2   23999
    49  118724837376.0  2   24499
    50  104463663104.0  2   24999
    51  106144661504.0  2   25499
    52  139560747008.0  2   25999
    53  190854529024.0  2   26499
    54  134427598848.0  2   26999
    55  395008507904.0  2   27499
    56  125777297408.0  2   27999
    57  97119354880.0   2   28499
    58  202578296832.0  2   28999
    59  132108419072.0  2   29499
    60  175519088640.0  2   29999
    61  62381907968.0   2   30499
  2. The dataset looks like this:

| Key | Description |
| --- | --- |
| molecule_name | String, molecule identifier. |
| atom_count | Integer, number of atoms. |
| bond_count | Integer, number of bonds. |
| elements | List of length atom_count. Each element gives the type of the atom; for example, for a water molecule, elements=['H', 'H', 'O']. |
| coordinates | List of length atom_count. The i-th element is a 3-tuple with the 3D coordinates (x, y, z) of the i-th atom. |
| connectivity | List of length atom_count. The i-th element is a list of all atoms connected to the i-th atom. |
| edge_list | List of length 2 * bond_count. Each element (i, j) represents an edge from atom i to atom j. |
| edge_attr | List of length 2 * bond_count. The value is the bond type: '1' single, '2' double, '3' triple. |
| formal_charge | List of length atom_count. The i-th element is the formal charge of the i-th atom, as a floating-point number. |
| energy | Floating-point number, the predicted molecular energy. |
| force | List of length atom_count * 3. The predicted molecular force field. |

Here is a sample:

{'mol_name': 1027776, 'atom_count': 33, 'bond_count': 34, 'connectivity': [[1, 12, 13, 14], [0, 2, 15, 16], [1, 3, 17], [2, 4, 10, 18], [3, 5, 19, 20], [4, 6, 9], [5, 7, 21, 22], [6, 8, 23, 24], [7, 9, 25, 26], [5, 8, 27, 28], [3, 11, 29, 30], [10, 12], [0, 11, 31, 32], [0], [0], [1], [1], [2], [3], [4], [4], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [12], [12]], 'edge_list': array([[ 0,  1],
       [ 0, 12],
       [ 0, 13],
       [ 0, 14],
       [ 1,  0],
       [ 1,  2],
       [ 1, 15],
       [ 1, 16],
       [ 2,  1],
       [ 2,  3],
       [ 2, 17],
       [ 3,  2],
       [ 3,  4],
       [ 3, 10],
       [ 3, 18],
       [ 4,  3],
       [ 4,  5],
       [ 4, 19],
       [ 4, 20],
       [ 5,  4],
       [ 5,  6],
       [ 5,  9],
       [ 6,  5],
       [ 6,  7],
       [ 6, 21],
       [ 6, 22],
       [ 7,  6],
       [ 7,  8],
       [ 7, 23],
       [ 7, 24],
       [ 8,  7],
       [ 8,  9],
       [ 8, 25],
       [ 8, 26],
       [ 9,  5],
       [ 9,  8],
       [ 9, 27],
       [ 9, 28],
       [10,  3],
       [10, 11],
       [10, 29],
       [10, 30],
       [11, 10],
       [11, 12],
       [12,  0],
       [12, 11],
       [12, 31],
       [12, 32],
       [13,  0],
       [14,  0],
       [15,  1],
       [16,  1],
       [17,  2],
       [18,  3],
       [19,  4],
       [20,  4],
       [21,  6],
       [22,  6],
       [23,  7],
       [24,  7],
       [25,  8],
       [26,  8],
       [27,  9],
       [28,  9],
       [29, 10],
       [30, 10],
       [31, 12],
       [32, 12]]), 'coordinates': array([[ 3.4393e+00, -7.7700e-01,  7.9550e-01],
       [ 2.5651e+00, -1.8494e+00,  1.0060e-01],
       [ 1.4435e+00, -1.3965e+00, -7.0260e-01],
       [ 4.7110e-01, -4.5660e-01, -1.4880e-01],
       [-9.4080e-01, -9.0180e-01, -6.0550e-01],
       [-2.0660e+00, -3.1410e-01,  8.9200e-02],
       [-2.5093e+00,  1.0391e+00, -2.6650e-01],
       [-4.0444e+00,  1.0275e+00, -9.6800e-02],
       [-4.3085e+00, -2.1650e-01,  7.6410e-01],
       [-3.2308e+00, -1.1761e+00,  2.6510e-01],
       [ 8.0030e-01,  9.9480e-01, -5.4050e-01],
       [ 2.2689e+00,  1.7525e+00,  2.4780e-01],
       [ 3.6073e+00,  5.2050e-01, -1.6000e-03],
       [ 4.4331e+00, -1.2202e+00,  9.9150e-01],
       [ 3.0209e+00, -5.1680e-01,  1.7813e+00],
       [ 3.2023e+00, -2.4492e+00, -5.7150e-01],
       [ 2.2039e+00, -2.5559e+00,  8.7050e-01],
       [ 1.7402e+00, -1.1098e+00, -1.6312e+00],
       [ 4.8710e-01, -5.3890e-01,  9.5080e-01],
       [-9.6530e-01, -1.9882e+00, -4.3720e-01],
       [-1.0139e+00, -7.6400e-01, -1.7129e+00],
       [-2.2289e+00,  1.3053e+00, -1.3069e+00],
       [-2.0452e+00,  1.7985e+00,  3.8930e-01],
       [-4.5326e+00,  9.0850e-01, -1.0781e+00],
       [-4.4302e+00,  1.9579e+00,  3.4680e-01],
       [-4.1453e+00,  7.3000e-03,  1.8314e+00],
       [-5.3280e+00, -6.1570e-01,  6.5100e-01],
       [-3.5596e+00, -1.6545e+00, -6.8810e-01],
       [-3.0140e+00, -1.9922e+00,  9.7570e-01],
       [-6.1000e-03,  1.6813e+00, -2.4850e-01],
       [ 8.8990e-01,  1.0585e+00, -1.6406e+00],
       [ 4.5003e+00,  1.0662e+00,  3.4150e-01],
       [ 3.7507e+00,  3.2990e-01, -1.0795e+00]]), 'elements': [6, 6, 7, 6, 6, 7, 6, 6, 6, 6, 6, 16, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'formal_charge': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'edge_attr': ['1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1'], 'energy': -564947.2428974451, 'force': array([ 1.19542500e-02,  1.50895600e-02,  8.90907000e-03, -8.46083000e-03,
        2.20930100e-02, -8.49166000e-03,  6.66573200e-02, -1.29301800e-02,
        1.73378800e-02,  1.11164700e-01,  8.90480910e-01,  1.36219420e+00,
       -7.11127990e-01, -1.89369298e+00, -1.88591875e+00,  1.48505939e+00,
        2.38219630e+00, -5.40378200e-02, -9.02942690e-01, -1.59174288e+00,
        4.12486320e-01, -1.30733100e-01,  7.06262600e-02, -2.19656570e-01,
       -7.88368800e-02, -9.18358300e-02,  7.38398300e-02,  2.39523800e-02,
        8.48454300e-02,  8.41398600e-02, -1.69240200e-02,  2.52864300e-02,
        1.06474800e-02,  3.66975700e-02, -3.03143000e-02,  6.22389500e-02,
       -5.76213700e-02,  3.27049000e-03, -1.21018870e-01,  9.97038000e-03,
       -1.88818100e-02, -1.15973200e-02,  2.00237600e-02, -1.88924900e-02,
        7.66396000e-03, -6.95602000e-03,  1.75572500e-02,  2.90763000e-03,
       -1.41658100e-02, -5.48323000e-03, -3.46259000e-03, -1.54654700e-02,
       -1.23171100e-02,  3.31721300e-02,  2.07115300e-02, -1.99455000e-03,
       -1.07317000e-03,  6.45442000e-03,  7.82600000e-05,  3.60015000e-02,
       -1.47148400e-02,  1.32800000e-02,  3.74008000e-03,  6.06890900e-02,
        2.65872700e-02,  8.00926600e-02, -2.37152200e-02,  8.50707000e-03,
       -2.26657600e-02,  5.92506900e-02,  2.46591300e-02,  5.49758000e-02,
        1.50741400e-02,  3.18393000e-03,  4.44421600e-02,  3.03451400e-02,
       -2.29207200e-02, -1.70236300e-02, -3.08657100e-02,  9.91725000e-02,
        3.12973500e-02,  5.38457200e-02,  3.95838600e-02, -3.66786000e-02,
        8.65600000e-05, -9.98935000e-03,  1.38373300e-02, -4.90941400e-02,
        3.74482000e-03,  1.77920400e-02,  1.16827000e-02, -5.61604000e-03,
       -1.07910000e-04,  1.15546300e-02, -1.42559300e-02, -9.62767000e-03,
        2.64509000e-02,  6.26110000e-04,  3.36452700e-02])}


  3. The reason I do not use the script you provided is that I have some package version conflicts, and it was more convenient to write it myself.
xiehuanyi commented 1 year ago

Maybe it's because I need more epochs to converge. However, do you have any suggestions for how to set the hyperparameters?

RaulPPelaez commented 1 year ago

I cannot offer any more insight regarding hyperparameters beyond what's laid out in the paper describing the architecture. Maybe someone else can chime in if you show us your current ones.

I do not know why your loss does not go down, but here are some things that caught my eye:

  1. We typically use MSE loss for train and eval; you are using L1.
  2. In the sample you provided, the energy is really large compared with the forces (around 6 orders of magnitude). Even though you try to compensate by weighting the force error 100x larger than the energy error, that may be giving you some trouble. Note that by default our train script weights energy and force equally, and note that you chose a different factor (10x) in eval. See the sketch below.
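
As a sketch of what I mean, using the names from your script above (the equal weighting mirrors the train script's default; force_weight = 1.0 is illustrative, not a tuned value):

from torch import nn

loss_fn = nn.MSELoss()
force_weight = 1.0  # torchmd-train weights energy and force equally by default

def combined_loss(pred_ene, pred_for, energy, forces):
    # squeeze the (N, 1) energy prediction so the shapes match exactly
    loss_ene = loss_fn(pred_ene.squeeze(-1), energy)
    loss_for = loss_fn(pred_for, forces)
    return loss_ene + force_weight * loss_for

Using the same combined_loss in both the training loop and evaluate() keeps the train and eval numbers directly comparable.
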
xiehuanyi commented 1 year ago

Thanks! I will check them. By the way, what would you suggest for such a large energy?

RaulPPelaez commented 1 year ago

Why are the energy and forces in your dataset so different in magnitude? Perhaps it is a matter of inconsistent units?

xiehuanyi commented 1 year ago

We use kcal/mol as the unit of energy. Should we convert the units so that the energy and forces match?

RaulPPelaez commented 1 year ago

The model does not really care about units. What I am worried about is that the sheer difference in magnitude between the numbers you are summing to compute the loss is causing numerical accuracy issues.
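
A standard remedy, sketched below on your d.npy layout, is to shift the energies by their dataset mean (or a per-atom reference energy) before training and add the constant back at inference; a constant shift leaves the forces, which are gradients of the energy, unchanged. Also make sure the units are self-consistent: with energies in kcal/mol and coordinates in Angstrom, the forces should be in kcal/mol/Angstrom.

import numpy as np

# Sketch: mean-shift the energy targets so their magnitude is closer
# to that of the forces. Remember to add e_mean back to predictions.
data = np.load('d.npy', allow_pickle=True)
e_mean = np.mean([rec['energy'] for rec in data])
for rec in data:
    rec['energy'] = rec['energy'] - e_mean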

xiehuanyi commented 1 year ago

I also tried training on the energy only, with the energy normalized. It doesn't seem to work for me. Maybe I should try some other models?

RaulPPelaez commented 1 year ago

I would start by trying to reproduce a known result. For instance, run torchmd-train with the ET-MD17.yaml example to get reference numbers for the val or train loss (note that it will be MSE loss). Then compare with your script and check whether you get similar convergence; otherwise, debug. This way you will be able to discern whether your error is due to the dataset, the hyperparameters, or a bug in your script.
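
For instance, with a checkout of this repository, something like the following (this assumes the example config lives in examples/; see torchmd-train --help for the exact flags):

torchmd-train --conf examples/ET-MD17.yaml --log-dir logs/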