Open ShangChien opened 1 year ago
Error: metrics function has NaN value
This error becomes more likely as the batch size increases, regardless of the task type.
Hello ShangChien, thanks for your interest in our project. However, we can't reproduce this error. Could you provide a minimal example that reproduces it?
from unimol_tools import MolTrain, MolPredict
import numpy as np
import pickle
with open('data.pkl', 'rb') as f:
    raw = pickle.load(f)
data = {
    'target': raw['target'],
    'atoms': raw['atoms'],
    'coordinates': raw['coordinates'],
}
clf = MolTrain(task='regression', data_type='molecule', epochs=10, batch_size=112, metrics='r2')
pred = clf.fit(data = data)
# clf = MolPredict(load_model='../exp')
# res = clf.predict(data = data)
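Before training, it can help to rule out bad inputs. The following is a minimal sketch, assuming the data dictionary has the same 'target' and 'coordinates' layout as the script above. The helper find_nonfinite_samples is hypothetical and not part of unimol_tools, and the toy values are made up for illustration.

```python
import numpy as np


def find_nonfinite_samples(data):
    """Return indices of samples whose target or coordinates contain NaN/inf.

    Hypothetical helper for illustration; not part of unimol_tools.
    """
    bad = []
    for i, (target, coords) in enumerate(zip(data["target"], data["coordinates"])):
        target_ok = np.all(np.isfinite(np.asarray(target, dtype=float)))
        coords_ok = np.all(np.isfinite(np.asarray(coords, dtype=float)))
        if not (target_ok and coords_ok):
            bad.append(i)
    return bad


# Toy data in the same layout as the script above (values are made up).
toy = {
    "target": [0.5, float("nan"), 1.2],
    "atoms": [["C", "O"], ["N"], ["C"]],
    "coordinates": [
        np.zeros((2, 3)),
        np.zeros((1, 3)),
        np.array([[0.0, np.inf, 0.0]]),
    ],
}
bad_indices = find_nonfinite_samples(toy)
print(bad_indices)
```

If this returns an empty list for the real data, the inputs are finite and the NaN more likely appears during training or prediction.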
$ python fit.py
2023-09-11 10:57:05 | unimol_tools/data/datareader.py | 138 | INFO | Uni-Mol(QSAR) | Anomaly clean with 3 sigma threshold: 18093 -> 17953
2023-09-11 10:57:13 | unimol_tools/train.py | 88 | INFO | Uni-Mol(QSAR) | Output directory already exists: ./exp
2023-09-11 10:57:13 | unimol_tools/train.py | 89 | INFO | Uni-Mol(QSAR) | Warning: Overwrite output directory: ./exp
2023-09-11 10:57:14 | unimol_tools/models/unimol.py | 114 | INFO | Uni-Mol(QSAR) | Loading pretrained weights from /share/home/wangjiawen/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/weights/mol_pre_all_h_220816.pt
2023-09-11 10:57:15 | unimol_tools/models/nnmodel.py | 103 | INFO | Uni-Mol(QSAR) | start training Uni-Mol:unimolv1
2023-09-11 10:58:52 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [1/10] train_loss: 1.0411, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000093, 95.7s
2023-09-11 11:00:26 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [2/10] train_loss: 1.0379, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000082, 92.4s
2023-09-11 11:01:57 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [3/10] train_loss: 1.0391, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000072, 91.9s
2023-09-11 11:03:29 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [4/10] train_loss: 1.0318, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000062, 91.7s
2023-09-11 11:05:01 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [5/10] train_loss: 1.0389, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000052, 91.4s
2023-09-11 11:06:32 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [6/10] train_loss: 1.0415, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000041, 90.9s
2023-09-11 11:08:03 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [7/10] train_loss: 1.0436, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000031, 91.6s
2023-09-11 11:09:35 | unimol_tools/tasks/trainer.py | 169 | INFO | Uni-Mol(QSAR) | Epoch [8/10] train_loss: 1.0356, val_loss: 1.0057, val_r2: 0.0275, lr: 0.000021, 92.0s
val: 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 32/33 [00:16<00:00, 2.44it/s, Epoch=Epoch 9/10, loss=nan]
Traceback (most recent call last):
File "/share/home/csq/csq/unimol/Uni-Mol/unimol_tools/unimol_tools/eg_jobs/tg/fit.py", line 15, in <module>
pred = clf.fit(data = data)
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/train.py", line 56, in fit
self.model.run()
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/models/nnmodel.py", line 120, in run
_y_pred = self.trainer.fit_predict(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/tasks/trainer.py", line 157, in fit_predict
y_preds, val_loss, metric_score = self.predict(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/tasks/trainer.py", line 254, in predict
metric_score = self.metrics.cal_metric(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/utils/metrics.py", line 197, in cal_metric
return self.cal_reg_metric(label, predict, nan_value)
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/utils/metrics.py", line 175, in cal_reg_metric
res_dict[metric_type] = nan_metric(label, predict)
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/utils/metrics.py", line 173, in nan_metric
def nan_metric(label, predict): return cal_nan_metric(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/unimol_tools-1.0.0-py3.10.egg/unimol_tools/utils/metrics.py", line 49, in cal_nan_metric
result.append(metric_func(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/utils/_param_validation.py", line 211, in wrapper
return func(*args, **kwargs)
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/metrics/_regression.py", line 989, in r2_score
y_type, y_true, y_pred, multioutput = _check_reg_targets(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/metrics/_regression.py", line 101, in _check_reg_targets
y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/utils/validation.py", line 959, in check_array
_assert_all_finite(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/utils/validation.py", line 124, in _assert_all_finite
_assert_all_finite_element_wise(
File "/share/home/csq/anaconda3/envs/unimol/lib/python3.10/site-packages/sklearn/utils/validation.py", line 173, in _assert_all_finite_element_wise
raise ValueError(msg_err)
ValueError: Input contains NaN.
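The traceback ends in scikit-learn rejecting predictions that contain NaN, which is consistent with loss=nan in the progress bar. A minimal diagnostic sketch, assuming predictions and labels are available as arrays: it counts non-finite entries and computes R² on the remaining ones with plain NumPy. The function r2_on_finite is hypothetical and not part of unimol_tools or scikit-learn, and it only helps locate the problem; it does not address why the model produces NaN.

```python
import numpy as np


def r2_on_finite(y_true, y_pred):
    """Compute R^2 on finite entries only and report how many were dropped.

    Hypothetical diagnostic helper; not part of unimol_tools or scikit-learn.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    mask = np.isfinite(y_true) & np.isfinite(y_pred)
    n_dropped = int((~mask).sum())
    yt, yp = y_true[mask], y_pred[mask]
    if yt.size < 2:
        # Not enough finite samples to define R^2.
        return float("nan"), n_dropped
    ss_res = np.sum((yt - yp) ** 2)
    ss_tot = np.sum((yt - yt.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return float(r2), n_dropped


# Made-up values: one prediction is NaN, the rest match the labels exactly.
r2, n_dropped = r2_on_finite([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, float("nan"), 4.0])
print(r2, n_dropped)
```

A non-zero dropped count on the validation set would confirm that the model output itself is NaN, as opposed to the labels.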
Sorry for the delay, @HongshuaiWang1.
The metrics function has a NaN value: https://github.com/dptech-corp/Uni-Mol/blob/b640dda4397f276d0873d74b332d7e4d1e9454f6/unimol_tools/unimol_tools/utils/metrics.py#L30-L51
multilabel_regression does not support the rmse metric as in the tutorial on Bohrium: https://github.com/dptech-corp/Uni-Mol/blob/b640dda4397f276d0873d74b332d7e4d1e9454f6/unimol_tools/unimol_tools/utils/metrics.py#L100-L113
Lastly, the loss for multilabel_regression is hard to decrease.
Looking forward to a solution, thank you! 👀