JDRadatti closed this issue 1 year ago.
Resolved: I was fine-tuning with different parameters than the ones I had used when training on ZINC. I also needed to update my custom dataset slightly.
Here is the bash script that worked, if you're curious:
#!/usr/bin/env bash
n_gpu=1
epoch=4
max_epoch=5000
batch_size=64
tot_updates=$((33000*epoch/batch_size/n_gpu))
warmup_updates=$((tot_updates*16/100))
CUDA_VISIBLE_DEVICES=0 fairseq-train \
--user-dir ../../graphormer \
--num-workers 16 \
--ddp-backend=legacy_ddp \
--user-data-dir ../../graphormer/data/customized_dataset \
--dataset-name dataset_name \
--task graph_prediction \
--criterion l1_loss \
--arch graphormer_slim \
--num-classes 1 \
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
--lr-scheduler polynomial_decay --power 1 --warmup-updates $warmup_updates --total-num-update $tot_updates \
--lr 2e-4 --end-learning-rate 1e-5 \
--batch-size $batch_size \
--fp16 \
--data-buffer-size 20 \
--encoder-layers 12 \
--encoder-embed-dim 80 \
--encoder-ffn-embed-dim 80 \
--encoder-attention-heads 8 \
--max-epoch $max_epoch \
--save-dir ./ckpts/dataset_name/ \
--pretrained-model-name zinc_graphormer_slim \
--seed 1
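For anyone adapting the script, here is the schedule arithmetic on its own, evaluated the same way bash evaluates it. It assumes the hard-coded 33000 is the number of training graphs; the post doesn't say so, and `num_train_graphs` is a hypothetical name for that constant. Note that `$(( ))` does integer division, so each step truncates.

```shell
#!/usr/bin/env bash
# Same schedule arithmetic as the launch script above.
# Assumption: 33000 is the number of training graphs; replace it with
# your own dataset size.
n_gpu=1
epoch=4
batch_size=64
num_train_graphs=33000   # hypothetical name for the hard-coded constant

# $(( )) uses integer division, so each step truncates toward zero.
tot_updates=$((num_train_graphs * epoch / batch_size / n_gpu))
warmup_updates=$((tot_updates * 16 / 100))   # 16% of updates spent warming up

echo "tot_updates=$tot_updates warmup_updates=$warmup_updates"
# prints: tot_updates=2062 warmup_updates=329
```

With these values, --total-num-update covers only about 4 epochs, while --max-epoch is 5000. As far as I know, fairseq's polynomial_decay scheduler holds the rate at --end-learning-rate once the update count passes --total-num-update, so most of this run would train at 1e-5.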
@JDRadatti Hi! Could you elaborate on how to handle custom datasets?
Hi! I am trying to fine-tune a model trained on ZINC using a custom dataset built from my own data.
I used the zinc.sh script to train on ZINC. I uploaded checkpoint_best.pt to Dropbox and added its link to PRETRAINED_MODEL_URL in pretrain/__init__.py. I am fairly sure I did this correctly, because evaluating the checkpoint gave an MAE of 0.06. Then I created a custom torch_geometric.data.InMemoryDataset and modified pyg_dataset_lookup_table.py to handle my new dataset.
Here is the command I am using to fine-tune:
Here is my error message:
It reports that "encoder.graph_encoder.final_layer_norm.weight" and "encoder.graph_encoder.final_layer_norm.bias" are missing from the state_dict. I get the same error when using the pre-trained models (e.g. pcqm4mv2_graphormer_base) on my custom dataset.
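Missing keys like these usually mean the model built at fine-tuning time has a different configuration from the one that produced the checkpoint, which matches the resolution at the top of the thread. One way to catch that before launching is to diff the architecture-related flags of the training script (e.g. zinc.sh) against the fine-tuning script. This is a rough sketch: it assumes each flag is written as `--name value` as in the working script above, and the function names and flag list are my own choices, not part of Graphormer or fairseq.

```shell
#!/usr/bin/env bash
# Sketch: diff the model-shape flags of two fairseq-train launch scripts.
# Assumption: flags are written as "--name value" (as in the script above).
# The flag list is illustrative, not exhaustive.

# Print "--flag value" pairs for architecture-related flags, sorted.
arch_flags() {
  grep -oE -e '--(arch|num-classes|encoder-[a-z-]+) [^ \\]+' "$1" | sort
}

# Show differing flags; returns 0 only when both scripts agree.
compare_arch() {
  diff <(arch_flags "$1") <(arch_flags "$2")
}

# Example (file names are hypothetical):
#   compare_arch zinc.sh finetune.sh && echo "architecture flags match"
```

Any line that diff prints (for example a different --encoder-embed-dim or --arch) points to a parameter that should match the checkpoint before it is loaded with --pretrained-model-name.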