rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License

pretrained models #6

Closed: msadegh97 closed this issue 2 years ago

msadegh97 commented 2 years ago

Hi @rampasek, do you have any plans to release your checkpoints for the PCQM4M dataset?

rampasek commented 2 years ago

Hi, I hadn't thought of it, but it's a good idea! However, I don't have the checkpoints at the moment.

yanzhangnlp commented 2 years ago

Hi @rampasek ,

Have you tried fine-tuning the pretrained checkpoints on the OGBG-MolPCBA and OGBG-MolHIV datasets, as Graphormer does?

Best, Yan

rampasek commented 2 years ago

Hi @yanzhangnlp,

I briefly tried basic fine-tuning and didn't reach competitive performance in my initial experiment. For Graphormer, the authors used FLAG [1] and a considerable hyper-parameter search (see Appendix B.2 of the original Graphormer paper). Perhaps that would be needed here too.

Best, Ladislav

[1] Kong et al. FLAG: Adversarial Data Augmentation for Graph Neural Networks. arXiv:2010.09891, 2020.
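For context, FLAG's core is a short inner loop: initialize a random perturbation on the inputs, take a few gradient *ascent* steps on that perturbation while averaging the weight gradient, then take one descent step. Below is a minimal, dependency-light sketch on a plain linear model with MSE loss; the function name and step sizes are illustrative, and the real method perturbs node feature embeddings and uses autograd rather than hand-written gradients.

```python
import numpy as np

def flag_step(X, y, w, rng, ascent_steps=3, ascent_size=1e-3, lr=0.05):
    """One FLAG-style training step (Kong et al., 2020), sketched for a
    linear model with 0.5*MSE loss. FLAG runs a few gradient-ascent steps
    on an input perturbation `delta` while accumulating the weight gradient."""
    n = len(y)
    delta = rng.uniform(-ascent_size, ascent_size, X.shape)  # init perturbation
    grad_w = np.zeros_like(w)
    for _ in range(ascent_steps):
        Xp = X + delta
        err = Xp @ w - y                            # dL/d(prediction)
        grad_w += (Xp.T @ err) / n / ascent_steps   # average the weight gradient
        grad_x = np.outer(err, w) / n               # loss gradient w.r.t. inputs
        delta += ascent_size * np.sign(grad_x)      # ascent: make inputs harder
    return w - lr * grad_w                          # descend on averaged gradient

# Toy usage: despite the adversarial perturbations, loss on clean data
# should still decrease over training steps.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 4))
y = X @ np.array([1.0, -2.0, 0.5, 3.0])
w = np.zeros(4)
for _ in range(200):
    w = flag_step(X, y, w, rng)
final_loss = 0.5 * np.mean((X @ w - y) ** 2)
```

The averaging over `ascent_steps` is what distinguishes FLAG from plain adversarial training: each intermediate perturbed input contributes to the parameter update, so the extra forward/backward passes are not wasted.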

yanzhangnlp commented 2 years ago

Hi @rampasek

Thank you for your reply! I am also trying to reproduce the GPS results on the ogbg-code2 dataset. I used the suggested hyperparameters, but the results are far below the reported numbers. May I ask if you have any suggestions for resolving this? I attach the training log below for your reference.

Best, Yan

Training log (condensed from the raw per-epoch dicts; 12,454,066 parameters, roughly 1,250 s per training epoch; precision and recall collapse to 0.0 together with F1 from epoch 1 onward, and the "Best so far" checkpoint stays at epoch 0 throughout):

| Epoch | LR        | Train loss | Train F1 | Val loss | Val F1  | Test loss | Test F1 |
|------:|----------:|-----------:|---------:|---------:|--------:|----------:|--------:|
| 0     | 0.0       | 8.5173     | 0.00052  | 8.5254   | 0.00051 | 8.5267    | 0.00052 |
| 1     | 5.0e-05   | 4.0409     | 6e-06    | 3.5133   | 0.0     | 3.5179    | 0.0     |
| 2     | 1.0e-04   | 3.6086     | 0.0      | 3.5218   | 0.0     | 3.5245    | 0.0     |
| 3     | 9.969e-05 | 3.6059     | 0.0      | 3.5208   | 0.0     | 3.5229    | 0.0     |
| 4     | 9.875e-05 | 3.6037     | 0.0      | 3.5171   | 0.0     | 3.5170    | 0.0     |
| 5     | 9.719e-05 | 3.6027     | 0.0      | 3.5154   | 0.0     | 3.5157    | 0.0     |
| 6     | 9.505e-05 | 3.6023     | 0.0      | 3.5179   | 0.0     | 3.5181    | 0.0     |
| 7     | 9.234e-05 | 3.6014     | 0.0      | 3.5154   | 0.0     | 3.5150    | 0.0     |
| 8     | 8.909e-05 | 3.6005     | 0.0      | 3.5173   | 0.0     | 3.5176    | 0.0     |
| 9     | 8.536e-05 | 3.6000     | 0.0      | 3.5154   | 0.0     | 3.5164    | 0.0     |
| 10    | 8.117e-05 | 3.5997     | 0.0      | 3.5129   | 0.0     | 3.5150    | 0.0     |
| 11    | 7.66e-05  | 3.5995     | 0.0      | 3.5136   | 0.0     | 3.5158    | 0.0     |
| 12    | 7.169e-05 | 3.5992     | 0.0      | 3.5118   | 0.0     | 3.5147    | 0.0     |
| 13    | 6.651e-05 | 3.5992     | 0.0      | 3.5134   | 0.0     | 3.5155    | 0.0     |
| 14    | 6.113e-05 | 3.5988     | 0.0      | 3.5146   | 0.0     | 3.5171    | 0.0     |
| 15    | 5.56e-05  | 3.5987     | 0.0      | 3.5134   | 0.0     | 3.5147    | 0.0     |
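The pattern in this log, where loss plateaus near 3.60 while precision, recall, and F1 all drop to exactly zero after epoch 1, often indicates degenerate predictions (for a subtoken-sequence task like ogbg-code2, e.g. emitting only padding or end-of-sequence tokens). A small hypothetical helper to flag such collapsed epochs in per-epoch stat dicts of this shape:

```python
def find_collapsed_epochs(records, metric="f1", tol=1e-9):
    """Given per-epoch stat dicts like the ones logged above, return the
    epochs whose `metric` is numerically zero, i.e. where the model's
    predictions have no overlap with the targets."""
    return [r["epoch"] for r in records if abs(r.get(metric, 0.0)) <= tol]

# Example mirroring the log above: F1 is tiny at epoch 0, exactly 0 after.
val_stats = [
    {"epoch": 0, "loss": 8.5254, "f1": 0.000510},
    {"epoch": 1, "loss": 3.5133, "f1": 0.0},
    {"epoch": 2, "loss": 3.5218, "f1": 0.0},
]
print(find_collapsed_epochs(val_stats))  # → [1, 2]
```

A check like this makes it easy to add an early abort to a sweep rather than burning ~10 GPU-hours on a run whose metric has already collapsed.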

rampasek commented 2 years ago

Hi @yanzhangnlp,

I'll look into it in the next few days and try to rerun ogbg-code2 with the current codebase (after the update to PyG2.0.4). Are you running the GPS config without any changes?

Here are the logs from 10 runs on which the paper results are based: https://github.com/rampasek/GraphGPS/blob/main/final-results.zip

For ogbg-code2 I also found a W&B plot (attached image: W&B-code2).

Sorry for the problems, I'll try to get back to you in the next few days. Ladislav

yanzhangnlp commented 2 years ago

Hi @rampasek ,

Yes, I used the default config file.

Best, Yan

rampasek commented 2 years ago

Hi @yanzhangnlp,

Here is W&B with 10 repeated runs (random seeds 0 through 9) for the default ogbg_code2 config: https://wandb.ai/gtransformers/rerun_code2?workspace=user-rampasek (see the best/test_f1 panel).

So far it seems to be replicating the curve I posted before (the jobs are still running as I write this).

I ran the code exactly as it is on GitHub, with no modifications: I made a clean copy and re-downloaded the dataset. I use Python 3.9, PyTorch 1.10, and PyG 2.0.4. I also ran `conda list --export` on my conda environment and attach the output: requirements_GraphGPS.txt

Please let me know if you figure out what the issue may be; it seems to work fine on my end. Ladislav

rampasek commented 2 years ago

Hi @msadegh97, please see PR #11 for running inference using a pre-trained GPS model.

With that, all questions in this issue have been addressed, so I'm closing it.