Closed by xiehuanyi 10 months ago
Assuming there are no bugs in your code, you seem to be running for just 3 epochs. What do you mean, then, when you say the loss does not converge? Could you provide some graphs or numbers? As far as I can tell, your example also does not provide information about the dataset you are using, so it's hard to tell!
Is there any reason you are not using the provided torchmd-train utility? I find it quite featureful! Most of your code seems related to the Dataset, in which case the provided Custom class might be what you are looking for: https://github.com/torchmd/torchmd-net/blob/dca66796d00680a79a7b7c85d6704d30d15dc84c/torchmdnet/datasets/custom.py#L7-L31 If there is some functionality you are missing, let us know so we can consider adding it.
index eval_loss epoch step
1 120946425856.0 0 499
2 126612365312.0 0 999
3 132942659584.0 0 1499
4 106031333376.0 0 1999
5 91581849600.0 0 2499
6 139767087104.0 0 2999
7 311060463616.0 0 3499
8 96562634752.0 0 3999
9 149120892928.0 0 4499
10 122255409152.0 0 4999
11 113533616128.0 0 5499
12 216064688128.0 0 5999
13 101125210112.0 0 6499
14 721705893888.0 0 6999
15 147561873408.0 0 7499
16 135868628992.0 0 7999
17 199190691840.0 0 8499
18 89850052608.0 0 8999
19 69984813056.0 0 9499
20 155524808704.0 0 9999
21 153351618560.0 1 10499
22 129449132032.0 1 10999
23 231814447104.0 1 11499
24 64472707072.0 1 11999
25 122545004544.0 1 12499
26 80596549632.0 1 12999
27 77701513216.0 1 13499
28 120324743168.0 1 13999
29 213828239360.0 1 14499
30 131378298880.0 1 14999
31 97231765504.0 1 15499
32 141100646400.0 1 15999
33 142118600704.0 1 16499
34 79063613440.0 1 16999
35 277870936064.0 1 17499
36 149919236096.0 1 17999
37 118923984896.0 1 18499
38 66087944192.0 1 18999
39 153931972608.0 1 19499
40 117702656000.0 1 19999
41 144633200640.0 1 20499
42 116774649856.0 2 20999
43 145603592192.0 2 21499
44 144544120832.0 2 21999
45 114598690816.0 2 22499
46 95283609600.0 2 22999
47 95340781568.0 2 23499
48 86069018624.0 2 23999
49 118724837376.0 2 24499
50 104463663104.0 2 24999
51 106144661504.0 2 25499
52 139560747008.0 2 25999
53 190854529024.0 2 26499
54 134427598848.0 2 26999
55 395008507904.0 2 27499
56 125777297408.0 2 27999
57 97119354880.0 2 28499
58 202578296832.0 2 28999
59 132108419072.0 2 29499
60 175519088640.0 2 29999
61 62381907968.0 2 30499
Key | Description |
---|---|
molecule_name | String, molecule identifier. |
atom_count | Integer, number of atoms. |
bond_count | Integer, number of bonds. |
elements | List, length equal to the number of atoms. Each element indicates the type of atom. For example, for a water molecule, elements=['H', 'H', 'O']. |
coordinates | List, length equal to the number of atoms. The i-th element is a 3-tuple representing the 3D coordinates (x, y, z) of the i-th atom. |
connectivity | List, length equal to the number of atoms. The i-th element is a list of all connected atoms to the i-th atom. |
edge_list | List, length equal to 2 times the number of bonds. Each element (i, j) represents an edge from atom i to atom j. |
edge_attr | List, length equal to 2 times the number of bonds. The value represents the bond type. '1': single bond, '2': double bond, '3': triple bond. |
formal_charge | List, length equal to the number of atoms. The i-th element represents the formal charge of the i-th atom, represented as a floating-point number. |
energy | Floating-point number, the predicted molecular energy. |
force | List, length equal to the number of atoms times 3. The predicted molecular force field. |
Here is a sample:
{'mol_name': 1027776, 'atom_count': 33, 'bond_count': 34, 'connectivity': [[1, 12, 13, 14], [0, 2, 15, 16], [1, 3, 17], [2, 4, 10, 18], [3, 5, 19, 20], [4, 6, 9], [5, 7, 21, 22], [6, 8, 23, 24], [7, 9, 25, 26], [5, 8, 27, 28], [3, 11, 29, 30], [10, 12], [0, 11, 31, 32], [0], [0], [1], [1], [2], [3], [4], [4], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [12], [12]], 'edge_list': array([[ 0, 1],
[ 0, 12],
[ 0, 13],
[ 0, 14],
[ 1, 0],
[ 1, 2],
[ 1, 15],
[ 1, 16],
[ 2, 1],
[ 2, 3],
[ 2, 17],
[ 3, 2],
[ 3, 4],
[ 3, 10],
[ 3, 18],
[ 4, 3],
[ 4, 5],
[ 4, 19],
[ 4, 20],
[ 5, 4],
[ 5, 6],
[ 5, 9],
[ 6, 5],
[ 6, 7],
[ 6, 21],
[ 6, 22],
[ 7, 6],
[ 7, 8],
[ 7, 23],
[ 7, 24],
[ 8, 7],
[ 8, 9],
[ 8, 25],
[ 8, 26],
[ 9, 5],
[ 9, 8],
[ 9, 27],
[ 9, 28],
[10, 3],
[10, 11],
[10, 29],
[10, 30],
[11, 10],
[11, 12],
[12, 0],
[12, 11],
[12, 31],
[12, 32],
[13, 0],
[14, 0],
[15, 1],
[16, 1],
[17, 2],
[18, 3],
[19, 4],
[20, 4],
[21, 6],
[22, 6],
[23, 7],
[24, 7],
[25, 8],
[26, 8],
[27, 9],
[28, 9],
[29, 10],
[30, 10],
[31, 12],
[32, 12]]), 'coordinates': array([[ 3.4393e+00, -7.7700e-01, 7.9550e-01],
[ 2.5651e+00, -1.8494e+00, 1.0060e-01],
[ 1.4435e+00, -1.3965e+00, -7.0260e-01],
[ 4.7110e-01, -4.5660e-01, -1.4880e-01],
[-9.4080e-01, -9.0180e-01, -6.0550e-01],
[-2.0660e+00, -3.1410e-01, 8.9200e-02],
[-2.5093e+00, 1.0391e+00, -2.6650e-01],
[-4.0444e+00, 1.0275e+00, -9.6800e-02],
[-4.3085e+00, -2.1650e-01, 7.6410e-01],
[-3.2308e+00, -1.1761e+00, 2.6510e-01],
[ 8.0030e-01, 9.9480e-01, -5.4050e-01],
[ 2.2689e+00, 1.7525e+00, 2.4780e-01],
[ 3.6073e+00, 5.2050e-01, -1.6000e-03],
[ 4.4331e+00, -1.2202e+00, 9.9150e-01],
[ 3.0209e+00, -5.1680e-01, 1.7813e+00],
[ 3.2023e+00, -2.4492e+00, -5.7150e-01],
[ 2.2039e+00, -2.5559e+00, 8.7050e-01],
[ 1.7402e+00, -1.1098e+00, -1.6312e+00],
[ 4.8710e-01, -5.3890e-01, 9.5080e-01],
[-9.6530e-01, -1.9882e+00, -4.3720e-01],
[-1.0139e+00, -7.6400e-01, -1.7129e+00],
[-2.2289e+00, 1.3053e+00, -1.3069e+00],
[-2.0452e+00, 1.7985e+00, 3.8930e-01],
[-4.5326e+00, 9.0850e-01, -1.0781e+00],
[-4.4302e+00, 1.9579e+00, 3.4680e-01],
[-4.1453e+00, 7.3000e-03, 1.8314e+00],
[-5.3280e+00, -6.1570e-01, 6.5100e-01],
[-3.5596e+00, -1.6545e+00, -6.8810e-01],
[-3.0140e+00, -1.9922e+00, 9.7570e-01],
[-6.1000e-03, 1.6813e+00, -2.4850e-01],
[ 8.8990e-01, 1.0585e+00, -1.6406e+00],
[ 4.5003e+00, 1.0662e+00, 3.4150e-01],
[ 3.7507e+00, 3.2990e-01, -1.0795e+00]]), 'elements': [6, 6, 7, 6, 6, 7, 6, 6, 6, 6, 6, 16, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'formal_charge': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'edge_attr': ['1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1'], 'energy': -564947.2428974451, 'force': array([ 1.19542500e-02, 1.50895600e-02, 8.90907000e-03, -8.46083000e-03,
2.20930100e-02, -8.49166000e-03, 6.66573200e-02, -1.29301800e-02,
1.73378800e-02, 1.11164700e-01, 8.90480910e-01, 1.36219420e+00,
-7.11127990e-01, -1.89369298e+00, -1.88591875e+00, 1.48505939e+00,
2.38219630e+00, -5.40378200e-02, -9.02942690e-01, -1.59174288e+00,
4.12486320e-01, -1.30733100e-01, 7.06262600e-02, -2.19656570e-01,
-7.88368800e-02, -9.18358300e-02, 7.38398300e-02, 2.39523800e-02,
8.48454300e-02, 8.41398600e-02, -1.69240200e-02, 2.52864300e-02,
1.06474800e-02, 3.66975700e-02, -3.03143000e-02, 6.22389500e-02,
-5.76213700e-02, 3.27049000e-03, -1.21018870e-01, 9.97038000e-03,
-1.88818100e-02, -1.15973200e-02, 2.00237600e-02, -1.88924900e-02,
7.66396000e-03, -6.95602000e-03, 1.75572500e-02, 2.90763000e-03,
-1.41658100e-02, -5.48323000e-03, -3.46259000e-03, -1.54654700e-02,
-1.23171100e-02, 3.31721300e-02, 2.07115300e-02, -1.99455000e-03,
-1.07317000e-03, 6.45442000e-03, 7.82600000e-05, 3.60015000e-02,
-1.47148400e-02, 1.32800000e-02, 3.74008000e-03, 6.06890900e-02,
2.65872700e-02, 8.00926600e-02, -2.37152200e-02, 8.50707000e-03,
-2.26657600e-02, 5.92506900e-02, 2.46591300e-02, 5.49758000e-02,
1.50741400e-02, 3.18393000e-03, 4.44421600e-02, 3.03451400e-02,
-2.29207200e-02, -1.70236300e-02, -3.08657100e-02, 9.91725000e-02,
3.12973500e-02, 5.38457200e-02, 3.95838600e-02, -3.66786000e-02,
8.65600000e-05, -9.98935000e-03, 1.38373300e-02, -4.90941400e-02,
3.74482000e-03, 1.77920400e-02, 1.16827000e-02, -5.61604000e-03,
-1.07910000e-04, 1.15546300e-02, -1.42559300e-02, -9.62767000e-03,
2.64509000e-02, 6.26110000e-04, 3.36452700e-02])}
Maybe it's because I need more epochs to converge. However, do you have any suggestions on how to set the hyperparameters?
I cannot offer any more insights regarding hparams besides what's laid out in the paper describing the architecture. Maybe someone else can chime in if you show us your current ones.
I do not know why your loss does not go down, but here are some things that caught my eye:
Thanks! I will check them. And by the way, what would you suggest for such a large energy?
Why are the energy and force in your dataset so different in magnitude? Perhaps it is a matter of inconsistent units?
We use kcal/mol as the unit of energy. Should we convert the units so that the force and energy match?
The model does not really care about units. What I am worried about is that the sheer difference in magnitude for the numbers you are summing to compute the loss is causing numerical accuracy issues.
I also tried computing the loss on the energy only and normalizing the energy. That does not seem to work for me either. Maybe I should train some other models?
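For reference, this is a minimal sketch of the kind of energy normalization meant here, assuming plain standardization with the mean and standard deviation of the training energies. The helper names (`fit_energy_scaler`, `normalize`, `denormalize`) are hypothetical and not part of torchmd-net, and every energy in the example except the first is made up for illustration.

```python
import numpy as np

def fit_energy_scaler(energies):
    """Return (mean, std) of the training energies, computed in float64."""
    energies = np.asarray(energies, dtype=np.float64)
    mean = energies.mean()
    std = energies.std()
    if std == 0.0:
        std = 1.0  # avoid division by zero for a degenerate dataset
    return mean, std

def normalize(energy, mean, std):
    """Map raw energies to zero mean and unit variance."""
    return (np.asarray(energy, dtype=np.float64) - mean) / std

def denormalize(z, mean, std):
    """Map model outputs back to the original energy scale."""
    return np.asarray(z, dtype=np.float64) * std + mean

# Hypothetical training energies: only the first value comes from the thread.
train_energies = [-564947.2428974451, -564950.1, -564944.8, -564946.3]

mean, std = fit_energy_scaler(train_energies)
z = normalize(train_energies, mean, std)
restored = denormalize(z, mean, std)
```

If forces are trained together with a standardized energy, they would need to be divided by the same `std` but not shifted, since a constant energy offset does not change the gradient. For molecules of different sizes, subtracting per-element reference energies may work better than a single global mean, but that is beyond this sketch.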
I would start by trying to reproduce a known result. For instance, run torchmd-train with the ET-MD17.yaml example to get some data for the val or train loss (note that it will be MSE loss). Then compare with your script and see if you are getting a similar convergence, otherwise debug. This way you will be able to discern if your error is due to the dataset, hyperparameters or a bug in your script.
I tried the provided equiformer, but it does not seem to converge when I use a constant learning rate. I don't know whether there is an error in my code.