ngruver / NOS

Protein Design with Guided Discrete Diffusion
MIT License

Having trouble replicating training process from README #1

Closed: jusjosgra closed this issue 1 year ago

jusjosgra commented 1 year ago


Firstly, thanks very much for the work that has gone into this; it's a really interesting approach, and I'm keen to investigate it further.

I have been trying to replicate the work here, but I'm struggling to fill in the gaps between the instructions for preparing the OAS dataset and the instructions for training a model. Looking at the Hydra configs, there seem to be some undocumented steps for preparing the training and validation datasets.

Could you consider updating the README to make replication easier, please?

ngruver commented 1 year ago

Thanks for the note, Justin!

I've updated the processing script and added a script for creating splits. I've also added the splits I used in my experiments to the "data" directory. Note that you'll have to modify the Hydra arguments to match the following, because I had to break the train CSV into pieces in order to add it to the repo:

data_dir: ./data
train_fn: train_*.csv
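For anyone reconstructing this data layout, here is a hypothetical sketch, not the repo's split script. It writes a random validation split to `val_iid.csv` (the `val_fn` used in the training command later in this thread; reading "iid" as a uniform random split is an assumption), writes the remaining rows as `train_0.csv`, `train_1.csv`, ..., and reads the shards back with the same glob that `train_fn: train_*.csv` implies. The function names and the `val_frac`, `n_shards`, and `seed` parameters are illustrative. The sketch also assumes the data module expands the glob itself.

```python
import glob
import os

import numpy as np
import pandas as pd


def write_splits(df, out_dir, val_frac=0.1, n_shards=4, seed=0):
    """Hypothetical: random (IID) train/val split, with the train CSV sharded."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(df))          # shuffle row positions
    n_val = int(len(df) * val_frac)
    val_df = df.iloc[idx[:n_val]]
    train_df = df.iloc[idx[n_val:]]

    os.makedirs(out_dir, exist_ok=True)
    val_df.to_csv(os.path.join(out_dir, "val_iid.csv"), index=False)
    # Split row positions (not the DataFrame itself) into roughly equal shards
    # so that each file stays small enough to commit to a repo.
    for i, part in enumerate(np.array_split(np.arange(len(train_df)), n_shards)):
        train_df.iloc[part].to_csv(os.path.join(out_dir, f"train_{i}.csv"), index=False)


def read_train(data_dir, train_fn="train_*.csv"):
    """Concatenate every shard matching the glob back into one DataFrame."""
    paths = sorted(glob.glob(os.path.join(data_dir, train_fn)))
    return pd.concat([pd.read_csv(p) for p in paths], ignore_index=True)
```

Sharding by row positions keeps every original row in exactly one file, so concatenating the shards recovers the full training set (in shuffled order).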

unkosan commented 1 year ago


I have also attempted to follow the instructions provided in the file, but I encountered some runtime errors.

First, I executed the following command for training the generative head.

CUDA_VISIBLE_DEVICES=1 PYTHONPATH="." python scripts/ model=mlm data_dir=/home/ryota-nakano/workspace/NOS-temp/data train_fn=train_*.csv val_fn=val_iid.csv vocab_file=/home/ryota-nakano/workspace/NOS-temp/vocab.txt log_dir=/home/ryota-nakano/workspace/NOS-temp/out

Upon executing this command, an error was thrown after the fourth epoch. The full traceback is as follows:

Traceback (most recent call last):                                                                                                     
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/", line 45, in main                                                                                                                                                              
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 531, in fit    
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 41, in _call_and_h
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)                                              
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/", line 91, in launch
    return function(*args, **kwargs)                                                                                                   
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 570, in _fit_im
    self._run(model, ckpt_path=ckpt_path)                                                                                              
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 975, in _run   
    results = self._run_stage()                                                                                                        
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1018, in _run_s
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 201, in run     
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 354, in advance                                                                                            
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 134, in run
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 248, in on_advance_end
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 177, in _decora
    return loop_run(self, *args, **kwargs)                         
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 122, in r
    return self.on_run_end()                                       
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 244, in o
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/loops/", line 325, in _
    call._call_callback_hooks(trainer, hook_name)                  
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 189, in _call_call
    fn(trainer, trainer.lightning_module, *args, **kwargs)                                                                             
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 100, in on_validation_epoch_end
    _, log = sample_model(                                         
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 151, in sample_model                                         
    seed_log, seed_wandb_log = metrics.evaluate_samples(                                                                               
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 215, in evaluate_samples                                    
    samp_df = labeler.label_seqs(s_for_labels)                     
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 98, in label_seqs                                           
    return pd.DataFrame([self.label_seq(s) for s in seqs])                                                                             
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 98, in <listcomp>                                           
    return pd.DataFrame([self.label_seq(s) for s in seqs])                                                                             
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 78, in label_seq                                            
    ss_frac = X.secondary_structure_fraction()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/", line 324, in secondary_structur
    aa_percentages = self.get_amino_acids_percent()                                                                                    
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/", line 112, in get_amino_acids_pe
    percentages = {aa: count / self.length for aa, count in aa_counts.items()}                                                         
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/Bio/SeqUtils/", line 112, in <dictcomp>
    percentages = {aa: count / self.length for aa, count in aa_counts.items()}                                                         
ZeroDivisionError: division by zero

While debugging, I tried changing line 371 of the file to samples = traj[-1], but this did not fully resolve the issue; it only delayed the same error until epoch 356.
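The traceback shows the failure inside Biopython's `get_amino_acids_percent`, which divides each residue count by `self.length`; that is only zero when a sampled sequence is empty. One possible workaround (a hypothetical sketch, not the maintainer's fix) is to skip empty samples before labeling. `label_seqs_safe` and `toy_label_seq` are illustrative names: `label_seq` stands in for the repo's `labeler.label_seq`, and the toy labeler just mimics the same divide-by-length pattern.

```python
import pandas as pd


def label_seqs_safe(seqs, label_seq):
    """Label only non-empty sequences; empty samples are skipped.

    `label_seq` is any callable mapping a sequence string to a dict of
    features (a stand-in for the repo's labeler, which is an assumption).
    """
    rows = []
    for s in seqs:
        if not s:  # length 0 would make count / length raise ZeroDivisionError
            continue
        row = dict(label_seq(s))
        row["seq"] = s  # keep the sequence so rows can be traced back
        rows.append(row)
    return pd.DataFrame(rows)


def toy_label_seq(seq):
    # Hypothetical labeler: divides by len(seq), like get_amino_acids_percent().
    return {"frac_A": seq.count("A") / len(seq)}
```

For example, `label_seqs_safe(["AAG", "", "GG"], toy_label_seq)` returns two rows instead of raising. If every sample is empty, the result is an empty DataFrame, which downstream metric code would still need to handle.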

Next, following the instructions for “vanilla infilling experiments,” I ran the following command using the checkpoint from epoch 303. Since the file “poas_seeds.csv” was not available, I used the “poas_seeds.txt” file in the root directory as a substitute.

CUDA_VISIBLE_DEVICES=1 PYTHONPATH="." python scripts/infill/ model=mlm ckpt_path="/home/ryota-nakano/workspace/NOS-temp/out/ar_mlm_test/models/best_by_train/epoch\=303-step\=216144.ckpt" +seeds_fn=/home/ryota-nakano/workspace/NOS-temp/poas_seeds.txt +results_dir=infill/ data_dir=/home/ryota-nakano/workspace/NOS-temp/data vocab_file=/home/ryota-nakano/workspace/NOS-temp/vocab.txt log_dir=/home/ryota-nakano/workspace/NOS-temp/infill_out

However, this raised the error below. It seems that poas_seeds.txt is not the correct file, and that some files required for replication are missing. Could you please check and provide the necessary files?

Traceback (most recent call last):
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/infill/", line 71, in <module>
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/", line 94, in decorated_main
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 394, in _run_hydra
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 457, in _run_app
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 223, in run_and_report
    raise ex
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 220, in run_and_report
    return func()
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 458, in <lambda>
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/_internal/", line 132, in run
    _ = ret.return_value
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/core/", line 260, in return_value
    raise self._return_value
  File "/home/ryota-nakano/miniconda3/envs/nos/lib/python3.10/site-packages/hydra/core/", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/ryota-nakano/workspace/NOS-temp/scripts/infill/", line 59, in main
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 337, in sample_outer_loop
  File "/home/ryota-nakano/workspace/NOS-temp/seq_models/", line 65, in make_sampling_csv
    for vh, vl in seeds:
ValueError: not enough values to unpack (expected 2, got 1)

ngruver commented 1 year ago

Apologies for being so slow to reply! I've added the missing CSV file here:

The CSV should be correctly split into vh and vl, so it should not lead to the ValueError you included above.
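The ValueError comes from `for vh, vl in seeds`, which expects each seed to be a (VH, VL) pair; a file with one sequence per line yields a single value per row. As a hypothetical sketch (not the repo's loader), and assuming the CSV header names its columns `vh` and `vl`, a loader that checks the columns up front and fails with a clearer message might look like this. `load_seed_pairs` is an illustrative name, and the demo sequences are placeholders.

```python
import csv
import io
from typing import IO, List, Tuple


def load_seed_pairs(f: IO[str], vh_col: str = "vh", vl_col: str = "vl") -> List[Tuple[str, str]]:
    """Read (VH, VL) seed pairs from a CSV file object.

    The column names "vh"/"vl" are an assumption, not confirmed against
    the repo's poas_seeds.csv.
    """
    reader = csv.DictReader(f)
    missing = {vh_col, vl_col} - set(reader.fieldnames or [])
    if missing:
        # Fail early instead of hitting "not enough values to unpack" later.
        raise ValueError(f"seed file is missing column(s): {sorted(missing)}")
    return [(row[vh_col], row[vl_col]) for row in reader]


if __name__ == "__main__":
    demo = io.StringIO("vh,vl\nEVQL,DIQM\n")  # placeholder fragments
    for vh, vl in load_seed_pairs(demo):
        print(vh, vl)
```

Validating the header at load time makes a wrong seed file (such as a one-column text file) report which columns are missing, rather than failing deep inside the sampling loop.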

Please let me know if you run into any other issues!
