ngruver / NOS

Protein Design with Guided Discrete Diffusion
https://arxiv.org/abs/2305.20009
MIT License

Having trouble replicating training process from README #1

Closed: jusjosgra closed this issue 1 year ago

jusjosgra commented 1 year ago

Hello,

Firstly, thanks very much for the work that has gone into this; it's a really interesting approach and I am keen to investigate it further.

I have been trying to replicate the work here, but I am struggling to fill in the gaps between the instructions for preparing the OAS dataset and the instructions for training a model. Looking at the hydra configs, there seem to be some undocumented steps for preparing the training and validation datasets.

Could you consider updating the README to facilitate replication, please?

ngruver commented 1 year ago

Thanks for the note, Justin!

I've updated the processing script and added a script for creating splits. I've also added the splits I used in my experiments to the "data" directory. Note that you'll have to set the hydra arguments as follows, since I had to break the train CSV into pieces to add it to the repo:

data_dir: ./data
train_fn: train_*.csv
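Because the training CSV was split into several files, train_fn is a glob pattern rather than a single filename. The sketch below illustrates what that implies, assuming the loader simply concatenates every file in data_dir that matches the pattern, in sorted order; the repo's actual loading code may differ. The helpers split_csv and load_split_csvs are hypothetical and are not part of NOS.

```python
import csv
import glob
import os


def split_csv(src_path, out_dir, rows_per_file, prefix="train_"):
    """Hypothetical helper: split one CSV into numbered chunks, repeating the header."""
    with open(src_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = list(reader)
    paths = []
    for i in range(0, len(rows), rows_per_file):
        # Zero-padded index so lexical (glob/sorted) order matches numeric order.
        path = os.path.join(out_dir, f"{prefix}{i // rows_per_file:03d}.csv")
        with open(path, "w", newline="") as out:
            writer = csv.writer(out)
            writer.writerow(header)
            writer.writerows(rows[i:i + rows_per_file])
        paths.append(path)
    return paths


def load_split_csvs(data_dir, pattern="train_*.csv"):
    """Hypothetical loader: concatenate rows from every file matching `pattern`."""
    paths = sorted(glob.glob(os.path.join(data_dir, pattern)))
    if not paths:
        raise FileNotFoundError(f"no files match {pattern!r} in {data_dir}")
    rows = []
    for path in paths:
        with open(path, newline="") as f:
            rows.extend(csv.DictReader(f))
    return rows
```

Failing loudly when the pattern matches nothing is deliberate: a wrong data_dir would otherwise surface later as a confusing empty-dataset error.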

unkosan commented 1 year ago

Hello.

I have also attempted to follow the instructions in README.md, but I encountered some runtime errors.

First, I ran the following command to train the generative head:

CUDA_VISIBLE_DEVICES=1 PYTHONPATH="." python scripts/train_seq_model.py model=mlm model.optimizer.lr=0.0005 data_dir=/home/ryota-nakano/workspace/NOS-temp/data train_fn=train_*.csv val_fn=val_iid.csv vocab_file=/home/ryota-nakano/workspace/NOS-temp/vocab.txt log_dir=/home/ryota-nakano/workspace/NOS-temp/out

This command failed after the fourth epoch with the following error:

Traceback (most recent call last):
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/train_seq_model.py", line 45, in main
    trainer.fit(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
    return function(*args, **kwargs)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1018, in _run_stage
    self.fit_loop.run()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 134, in run
    self.on_advance_end()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 248, in on_advance_end
    self.val_loop.run()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 122, in run
    return self.on_run_end()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 244, in on_run_end
    self._on_evaluation_epoch_end()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 325, in _on_evaluation_epoch_end
    call._call_callback_hooks(trainer, hook_name)
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 189, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/trainer.py", line 100, in on_validation_epoch_end
    _, log = sample_model(
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/sample.py", line 151, in sample_model
    seed_log, seed_wandb_log = metrics.evaluate_samples(
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/metrics.py", line 215, in evaluate_samples
    samp_df = labeler.label_seqs(s_for_labels)
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/metrics.py", line 98, in label_seqs
    return pd.DataFrame([self.label_seq(s) for s in seqs])
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/metrics.py", line 98, in <listcomp>
    return pd.DataFrame([self.label_seq(s) for s in seqs])
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/metrics.py", line 78, in label_seq
    ss_frac = X.secondary_structure_fraction()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/ProtParam.py", line 324, in secondary_structure_fraction
    aa_percentages = self.get_amino_acids_percent()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/ProtParam.py", line 112, in get_amino_acids_percent
    percentages = {aa: count / self.length for aa, count in aa_counts.items()}
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/ProtParam.py", line 112, in <dictcomp>
    percentages = {aa: count / self.length for aa, count in aa_counts.items()}
ZeroDivisionError: division by zero

While debugging, I tried changing line 371 of mlm_diffusion.py to samples = traj[-1]. This did not fully resolve the issue; it only delayed the same error until epoch 356.
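The last frames of the traceback show ProtParam computing count / self.length, so the ZeroDivisionError means at least one sampled sequence reaching label_seq in seq_models/metrics.py was empty. Why it was empty is a guess (for example, every position decoding to a special token that is later stripped). One possible workaround, sketched here as an assumption rather than the repo's fix, is to skip empty sequences before labeling. safe_label_seqs and toy_label are hypothetical; in practice you would pass the labeler's own label_seq.

```python
import math
from typing import Callable, Dict, List


def safe_label_seqs(
    seqs: List[str],
    label_seq: Callable[[str], Dict[str, float]],
    keys: List[str],
) -> List[Dict[str, float]]:
    """Hypothetical guard: label each sequence, using NaNs for empty ones."""
    results = []
    for s in seqs:
        if len(s) == 0:
            # Length-normalized metrics would divide by zero here; emit NaN instead.
            results.append({k: math.nan for k in keys})
        else:
            results.append(label_seq(s))
    return results


def toy_label(seq: str) -> Dict[str, float]:
    """Hypothetical stand-in labeler: fraction of glycine residues (divides by length)."""
    return {"frac_G": seq.count("G") / len(seq)}
```

Returning NaN instead of dropping the row keeps the output aligned with the input list, and pandas aggregations such as mean() skip NaN by default. Counting empty samples separately may still be worthwhile, since a rising count points at the sampler rather than the metrics.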

Next, I ran the command below, as described in the README's instructions for the “vanilla infilling experiments”, using the checkpoint from epoch 303. Since the file “poas_seeds.csv” was not available, I used the “poas_seeds.txt” file in the repository root as a substitute.

CUDA_VISIBLE_DEVICES=1 PYTHONPATH="." python scripts/infill/run_diffusion.py model=mlm ckpt_path="/home/ryota-nakano/workspace/NOS-temp/out/ar_mlm_test/models/best_by_train/epoch\=303-step\=216144.ckpt" +seeds_fn=/home/ryota-nakano/workspace/NOS-temp/poas_seeds.txt +results_dir=infill/ data_dir=/home/ryota-nakano/workspace/NOS-temp/data vocab_file=/home/ryota-nakano/workspace/NOS-temp/vocab.txt log_dir=/home/ryota-nakano/workspace/NOS-temp/infill_out

This raised the error below. It seems that poas_seeds.txt is not the correct file, and some files required for replication appear to be missing. Could you check and provide them?

Traceback (most recent call last):
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/infill/run_diffusion.py", line 71, in <module>
    main()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/infill/run_diffusion.py", line 59, in main
    sample_outer_loop(
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/sample.py", line 337, in sample_outer_loop
    make_sampling_csv(
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/sample.py", line 65, in make_sampling_csv
    for vh, vl in seeds:
ValueError: not enough values to unpack (expected 2, got 1)

ngruver commented 1 year ago

Apologies for being so slow to reply! I've added the missing CSV file here: https://github.com/ngruver/NOS/blob/main/poas_seeds.csv

The CSV should be correctly split into vh and vl, so it should not lead to the ValueError you included above.
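The ValueError comes from the loop for vh, vl in seeds in make_sampling_csv, which expects each seed to be a (heavy chain, light chain) pair; the "got 1" suggests the .txt file produced one-field rows. The sketch below loads pairs and fails early with a clear message when a row has the wrong shape. It assumes a header row followed by rows of exactly two comma-separated fields, heavy chain first, so check poas_seeds.csv for its actual layout. read_seed_pairs is a hypothetical helper, not part of NOS.

```python
import csv
from typing import List, Tuple


def read_seed_pairs(path: str, has_header: bool = True) -> List[Tuple[str, str]]:
    """Hypothetical helper: read (vh, vl) seed pairs from a two-column CSV."""
    pairs = []
    with open(path, newline="") as f:
        reader = csv.reader(f)
        if has_header:
            next(reader, None)  # skip the (assumed) header row
        for line_no, row in enumerate(reader, start=2 if has_header else 1):
            if not row:
                continue  # csv.reader yields [] for blank lines
            if len(row) != 2:
                raise ValueError(
                    f"{path}:{line_no}: expected 2 fields (vh, vl), got {len(row)}"
                )
            vh, vl = (field.strip() for field in row)
            pairs.append((vh, vl))
    return pairs
```

Validating the shape at load time turns the unpacking error deep inside sampling into a message that names the offending file and line.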

Please let me know if you run into any other issues!

Nate