alibaba / graph-gpt

Graph Learning with Generative Pretrained Transformers
MIT License
54 stars 3 forks

Task_type choices #3

Open yangzhang33 opened 3 weeks ago

yangzhang33 commented 3 weeks ago

Hello, thanks for this work! When running pcqm4m_v2_pretrain.sh, I noticed the line task_type="pretrain-mlm" # pretrain pretrain-mlm pretrain-ltp pretrain-euler, so it seems we are pretraining with the MLM objective. Can you explain what the other choices mean, and how to enable the NTP task described in the paper?

zhaoqf123 commented 3 weeks ago

> Hello, thanks for this work! When running pcqm4m_v2_pretrain.sh, I noticed the line task_type="pretrain-mlm" # pretrain pretrain-mlm pretrain-ltp pretrain-euler, so it seems we are pretraining with the MLM objective. Can you explain what the other choices mean, and how to enable the NTP task described in the paper?

Thank you for your interest in our work.

The functions that define the task_type options are in https://github.com/alibaba/graph-gpt/blob/main/src/utils/tokenizer_utils.py.

Briefly speaking, pretrain is for NTP (next-token prediction), and pretrain-mlm is for MLM as in BERT, or Scheduled Masked Token Prediction (SMTP) as in MaskGIT.
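The difference between the two objectives can be sketched with a toy example (illustrative only, not the repo's code; token ids, MASK_ID, and the -100 ignore label are assumptions, following the common Hugging Face convention):

```python
# Illustrative sketch: how NTP and MLM training targets differ for one
# tokenized sequence. Token ids are made up.
import random

MASK_ID = 0
IGNORE = -100  # label value ignored by the loss (Hugging Face convention)

seq = [5, 8, 3, 9, 2, 7]

# NTP: the label at position i is the token at position i+1.
ntp_inputs = seq[:-1]
ntp_labels = seq[1:]

# MLM: randomly mask ~15% of tokens; only masked positions carry labels.
random.seed(1)
mlm_inputs, mlm_labels = [], []
for tok in seq:
    if random.random() < 0.15:
        mlm_inputs.append(MASK_ID)
        mlm_labels.append(tok)      # predict the original token here
    else:
        mlm_inputs.append(tok)
        mlm_labels.append(IGNORE)   # no loss at unmasked positions
```

SMTP generalizes the fixed MLM masking ratio to a schedule that varies the masked fraction during training.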

In pretrain-ltp, ltp stands for last-token prediction. pretrain-euler means predicting another Eulerian sequence of the same graph, like "translating" one Eulerian path into another.

To enable NTP pre-training, just use task_type pretrain. In the (outdated) paper, the sequence is prolonged as shown in https://github.com/alibaba/graph-gpt/blob/main/pic/serializing.png. To use that serialization, set tokenizer_class="GSTTokenizer".
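Concretely, that amounts to two variable changes in the pre-train bash script (the file path is assumed; the variable names task_type and tokenizer_class come from the thread above):

```shell
# In pcqm4m_v2_pretrain.sh, switch from the MLM objective to NTP:
task_type="pretrain"            # was: "pretrain-mlm"
tokenizer_class="GSTTokenizer"  # prolonged-sequence serialization from the paper
```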

By the way, on the PCQM4Mv2 data, our experiments show that the result obtained with NTP (0.09) is worse than with SMTP (0.086) at the same model size.

Our best result on the valid dataset is 0.0844 (0.0856 on the leaderboard), and after adding valid data during fine-tuning, we achieve 0.0802 on the leaderboard. Currently, GraphGPT does not support adding 3D data yet.

If you have any more questions, feel free to ask. @yangzhang33

zhaoqf123 commented 3 weeks ago

By the way, models trained with SMTP can also be used to generate new samples. See MaskGIT for details.
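The MaskGIT-style generation loop can be sketched as follows (a toy, not the repo's API: a trained SMTP model would supply per-position token predictions and confidences, so a stand-in predictor with fixed outputs is used here):

```python
# Toy sketch of MaskGIT-style iterative decoding with an SMTP model.
import math
import random

MASK = -1
target = [5, 8, 3, 9, 2, 7]  # what the stand-in "model" wants to produce

def fake_model(tokens):
    """Stand-in for the SMTP model: (predicted_token, confidence) per position."""
    rng = random.Random(42)
    return [(t, rng.random()) for t in target]

def maskgit_decode(length, steps=4):
    tokens = [MASK] * length
    for step in range(1, steps + 1):
        preds = fake_model(tokens)
        # Cosine schedule: fraction of positions still masked after this step.
        frac_masked = math.cos(math.pi / 2 * step / steps)
        n_keep_masked = int(frac_masked * length)
        # Fill every masked position, then re-mask the least confident ones.
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        for i in masked:
            tokens[i] = preds[i][0]
        masked.sort(key=lambda i: preds[i][1])  # ascending confidence
        for i in masked[:n_keep_masked]:
            tokens[i] = MASK
    return tokens
```

At the final step the cosine schedule reaches zero, so every position is committed and a full sequence is returned.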

yangzhang33 commented 1 week ago

Hello, thanks for the response, it answered my question. If I want to reproduce the results, is it correct to run pcqm4m_v2_pretrain.sh and then pcqm4m_v2_supervised.sh with the default settings?

Thanks in advance for your kind help.

zhaoqf123 commented 1 week ago

The hyperparameters (hps) in the bash files for reproducing the results are as follows.

The key hps are max-lr, path-dropout and total_tokens.

To save time, we recommend starting with the mini or small models and gradually increasing the model size. To help you stay on track, we also provide our fine-tuning results for each model size.

In the fine-tuning stage, true_valid=-1 means NO valid data is added to training. true_valid=10000 means we hold out 10000 samples from the valid data to validate model performance and add the remaining 63545 samples to training.
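The true_valid convention can be sketched like this (a simplified assumption about the behavior, not the repo's actual implementation; the function name split_valid is made up):

```python
# Sketch of the true_valid convention: -1 keeps all valid samples out of
# training; N holds out N samples and moves the rest into the training set.
def split_valid(valid_indices, true_valid):
    if true_valid == -1:
        return valid_indices, []  # (held-out valid, added-to-train)
    holdout = valid_indices[:true_valid]
    to_train = valid_indices[true_valid:]
    return holdout, to_train

# 10000 held out + 63545 moved to training = 73545 valid samples in total.
valid = list(range(73545))
holdout, extra_train = split_valid(valid, 10000)
```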

If you have any questions, feel free to contact us. @yangzhang33

|  | pre-train | fine-tune |
|---|---|---|
| model-size | mini/small/medium/base/base24/base48 | mini/small/medium/base/base24/base48 |
| batch-size | 1024/1024/1024/1024/8192/8192 | 1024 |
| total_tokens / epochs | 1/1/1/1/4/4 × 10^9 tokens | 32 epochs |
| warmup_tokens / epochs_warmup | 1 × 10^8 tokens | 9.6 epochs |
| lr scheduler | warmup & linear decay | warmup & cosine decay |
| max-lr | 3.0e-4 | 6/6/6/6/2/1.5 × 10^-4 |
| min-lr | 0 | automatically set |
| Adam-betas | [0.9, 0.95] | [0.9, 0.99] |
| Adam-eps | 1.0e-8 | 1.0e-10 |
| max_grad_norm | 5 | 1 |
| weight-decay | 0.1 | 0.02 |
| attention-dropout | 0.1 | 0.1 |
| path-dropout | 0 | 0/0/0/0.05/0.1/0.2 |
| embed-dropout | 0 | 0 |
| mlp-dropout | 0 | 0 |
| layer_scale_init_val | NA | 1 |
| EMA | NA | NA |
| hidden-act | gelu | gelu |
| max-position-embedding | 1024 | 1024 |
| tie_word_embeddings | FALSE | NA |
| ft-results w. true_valid=-1 | NA | 0.107/0.096/0.0886/0.0864/0.0845/0.0844 |
| ft-results w. true_valid=10000 | NA | 0.102/0.09/0.083/0.081/0.079/0.079 |
| ft-results w. true_valid=-1 on leaderboard | NA | NA/NA/NA/NA/NA/0.0856 |
| ft-results w. true_valid=10000 on leaderboard | NA | NA/NA/NA/NA/NA/0.0804 |

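
The two learning-rate schedules above can be sketched as follows (illustrative only, not the repo's scheduler; the step counts in the usage are made up):

```python
# Sketch of the two LR schedules: both warm up linearly from 0 to max-lr,
# then pre-train decays linearly and fine-tune decays along a cosine curve.
import math

def warmup_linear(step, warmup, total, max_lr, min_lr=0.0):
    if step < warmup:
        return max_lr * step / warmup
    frac = (step - warmup) / (total - warmup)  # 0 at end of warmup, 1 at end
    return max_lr + frac * (min_lr - max_lr)

def warmup_cosine(step, warmup, total, max_lr, min_lr=0.0):
    if step < warmup:
        return max_lr * step / warmup
    frac = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * frac))
```

For example, with warmup=100 and total=1100 steps, both schedules reach max-lr at step 100 and return to min-lr at step 1100; they differ only in the shape of the decay in between.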
yangzhang33 commented 1 week ago

@zhaoqf123 Many thanks for the detailed explanation, this helps a lot. I will try them out. Thanks!