aqlaboratory / openfold

Trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2
Apache License 2.0
2.73k stars 514 forks

Finetune from AF parameters #421

Open C-de-Furina opened 6 months ago

C-de-Furina commented 6 months ago

Hi, I'm trying to train this model from AlphaFold-Multimer pretrained parameters. I notice that although it resumes from the JAX parameters, the loss is still pretty high and the model is far from being usable. Is this expected? My arguments are below.

--template_release_dates_cache_path=/home/xiaoguo.chen/openfold/mmcif_cache.json \
--precision=bf16 \
--gpus=2 \
--replace_sampler_ddp=True \
--seed=42 \
--deepspeed_config_path=/home/xiaoguo.chen/openfold/deepspeed_config.json \
--checkpoint_every_epoch \
--train_chain_data_cache_path=/home/xiaoguo.chen/openfold/chain_data_cache.json \
--obsolete_pdbs_file_path=/home/xiaoguo.chen/xiaoguo.chen/AFdata/pdb_mmcif/obsolete.dat \
--max_epochs=20 \
--train_epoch_len=100 \
--config_preset="model_5_multimer_v3" \
--num_nodes=1 \
--resume_from_jax_params=/home/xiaoguo.chen/xiaoguo.chen/AFdata/params/params_model_5_multimer_v3.npz \
--train_mmcif_data_cache_path=/home/xiaoguo.chen/openfold/mmcif_cache.json

[screenshot of training output]

christinaflo commented 5 months ago

Hi, so for multimer you only need the train_mmcif_data_cache_path, not the train_chain_data_cache_path; that one is only required for monomer training. Are you using precomputed alignments for training? I don't see them in the arguments listed. Also, could you share your loss curves over a few epochs?
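To make that concrete, a multimer fine-tuning launch along these lines is what I'd expect (just a sketch: the five positional paths below, including the precomputed-alignment directory, are placeholders for your own directories, the order assumes the usual train_openfold.py signature, and the flags reuse yours minus --train_chain_data_cache_path):

```bash
# Sketch only: multimer fine-tuning without --train_chain_data_cache_path.
# The five positional arguments (training mmCIFs, precomputed alignments,
# template mmCIFs, output dir, max template date) are placeholders; swap in
# your own paths and keep your existing flag values.
python3 train_openfold.py \
    /path/to/train_mmcifs \
    /path/to/precomputed_alignments \
    /path/to/pdb_mmcif/mmcif_files \
    /path/to/output_dir \
    2021-10-10 \
    --train_mmcif_data_cache_path=/home/xiaoguo.chen/openfold/mmcif_cache.json \
    --template_release_dates_cache_path=/home/xiaoguo.chen/openfold/mmcif_cache.json \
    --obsolete_pdbs_file_path=/home/xiaoguo.chen/xiaoguo.chen/AFdata/pdb_mmcif/obsolete.dat \
    --resume_from_jax_params=/home/xiaoguo.chen/xiaoguo.chen/AFdata/params/params_model_5_multimer_v3.npz \
    --config_preset="model_5_multimer_v3" \
    --deepspeed_config_path=/home/xiaoguo.chen/openfold/deepspeed_config.json \
    --precision=bf16 \
    --gpus=2 \
    --num_nodes=1 \
    --replace_sampler_ddp=True \
    --seed=42 \
    --max_epochs=20 \
    --train_epoch_len=100 \
    --checkpoint_every_epoch
```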

C-de-Furina commented 5 months ago

Hi, so for multimer you only need the train_mmcif_data_cache_path, not the train_chain_data_cache_path; that one is only required for monomer training. Are you using precomputed alignments for training? I don't see them in the arguments listed. Also, could you share your loss curves over a few epochs?

Thank you. Yes, I use precomputed alignments. The mgnify_hits and uniref90_hits were downloaded, and the uniprot_hits were generated locally from the .cif files.

As for the curves, I've seen many others share their loss curves, but I'm not sure how to generate them. Could you tell me how to do that?

christinaflo commented 5 months ago

I use Weights & Biases. If you make an account and pass the parameters --wandb, --experiment_name, --wandb_project, and --wandb_entity to the training script, it will log the run for you, as well as all the validation metrics. It would be helpful to see how those metrics look for your fine-tuning run.
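For example, these four flags are simply appended to the existing train_openfold.py invocation (a sketch; the experiment, project, and entity names below are placeholders you would pick yourself):

```bash
# Sketch: W&B logging flags appended to the existing training command.
# "my_finetune_run", "openfold_finetune", and "my_team" are placeholder names.
    --wandb \
    --experiment_name=my_finetune_run \
    --wandb_project=openfold_finetune \
    --wandb_entity=my_team
```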

C-de-Furina commented 5 months ago

I use Weights & Biases. If you make an account and pass the parameters --wandb, --experiment_name, --wandb_project, and --wandb_entity to the training script, it will log the run for you, as well as all the validation metrics. It would be helpful to see how those metrics look for your fine-tuning run.

I'm working on generating them and will post the curves later. By the way, I downloaded some alignments from RODA, and then used the .cif files from the AlphaFold dataset to generate full MSAs for the missing chains and uniprot_hits for all chains.

In other words: suppose AB.cif from the AlphaFold dataset contains two chains, A and B. If RODA has A-mgnify_hits and A-uniref90_hits but no alignments for B, then I generate A-uniprot_hits myself, along with all three sets of hits for B.

Is this feasible? I'm not sure whether I can combine them like that.

C-de-Furina commented 5 months ago

I use Weights & Biases. If you make an account and pass the parameters --wandb, --experiment_name, --wandb_project, and --wandb_entity to the training script, it will log the run for you, as well as all the validation metrics. It would be helpful to see how those metrics look for your fine-tuning run.

Hi, I've generated some curves.

[loss curve screenshots]

Compared to the OpenFold paper, lddt_ca and drmsd_ca are similar, but all the other losses are much higher.

Furthermore, I found that the parameters become chaotic after only a few training steps. The first prediction is generated with the AlphaFold parameters, the second with a checkpoint after three training steps, and the third after 100 steps.

[prediction screenshots]

It seems that the training completely breaks the pretrained parameters instead of fine-tuning them, so I'm pretty confused. Did I make a mistake in the arguments? Thank you.

vetmax7 commented 5 months ago

@C-de-Furina Hello!

Could you explain how you used the checkpoint? Did you pass --openfold_checkpoint_path checkpoint.ckpt as extra tuning on top of an AlphaFold model (model_1, for example), or as a model by itself?

C-de-Furina commented 5 months ago

@C-de-Furina Hello!

Could you explain how you used the checkpoint? Did you pass --openfold_checkpoint_path checkpoint.ckpt as extra tuning on top of an AlphaFold model (model_1, for example), or as a model by itself?

If you mean how I generate predictions, then I use these two arguments together:

--config_preset model_5_multimer_v3 --openfold_checkpoint_path checkpoint.ckpt
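For completeness, the full inference call looks roughly like this (a sketch: the FASTA, template, alignment, and output paths are placeholders, and run_pretrained_openfold.py's exact argument list may differ between versions):

```bash
# Sketch: running inference from a fine-tuned checkpoint.
# All paths except the checkpoint and config preset are placeholders.
python3 run_pretrained_openfold.py \
    /path/to/fasta_dir \
    /path/to/pdb_mmcif/mmcif_files \
    --use_precomputed_alignments /path/to/alignments \
    --output_dir /path/to/predictions \
    --config_preset model_5_multimer_v3 \
    --openfold_checkpoint_path checkpoint.ckpt
```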

vetmax7 commented 5 months ago

@C-de-Furina Hello! Could you explain how you used the checkpoint? Did you pass --openfold_checkpoint_path checkpoint.ckpt as extra tuning on top of an AlphaFold model (model_1, for example), or as a model by itself?

If you mean how I generate predictions, then I use these two arguments together:

--config_preset model_5_multimer_v3 --openfold_checkpoint_path checkpoint.ckpt

Ok, thanks!

One more question: have you tried converting the checkpoint .ckpt file to a model .npz itself, and using your own model instead of AlphaFold's model_5?

C-de-Furina commented 5 months ago

@C-de-Furina Hello! Could you explain how you used the checkpoint? Did you pass --openfold_checkpoint_path checkpoint.ckpt as extra tuning on top of an AlphaFold model (model_1, for example), or as a model by itself?

If you mean how I generate predictions, then I use these two arguments together: --config_preset model_5_multimer_v3 --openfold_checkpoint_path checkpoint.ckpt

Ok, thanks!

One more question: have you tried converting the checkpoint .ckpt file to a model .npz itself, and using your own model instead of AlphaFold's model_5?

No, actually I don't even know how to do what you're describing.
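If it only means dumping the raw tensors into an .npz, I'd guess something along these lines could work (only a guess: the file names are placeholders, a DeepSpeed checkpoint may first need to be consolidated into a single .ckpt, and the array names stay in OpenFold's PyTorch layout rather than AlphaFold's JAX parameter format):

```bash
# Sketch: dump the tensors of a consolidated .ckpt into an .npz file.
# The resulting arrays keep OpenFold's PyTorch parameter names, so this is
# NOT a drop-in replacement for AlphaFold's params_model_*.npz layout.
python3 - <<'EOF'
import numpy as np
import torch

ckpt = torch.load("checkpoint.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # Lightning nests weights under "state_dict"
arrays = {k: v.cpu().numpy() for k, v in state_dict.items() if torch.is_tensor(v)}
np.savez("openfold_params.npz", **arrays)
EOF
```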

vetmax7 commented 5 months ago

Ok, thanks!

abhinavb22 commented 1 week ago

Did any of you encounter problems with DeepSpeed during multi-node/multi-GPU training like this: #486? Any suggestions would be really helpful.