kuleshov-group / caduceus

Bi-Directional Equivariant Long-Range DNA Sequence Modeling
Apache License 2.0
137 stars 14 forks

Unexpectedly low test accuracy when reproducing dummy_mouse_enhancers_ensembl experiments #19

Closed zhan8855 closed 2 months ago

zhan8855 commented 2 months ago

Hi, thank you for your awesome work!

I would greatly appreciate it if you could help me check the following command and logs from my attempt to reproduce the experiments on dummy_mouse_enhancers_ensembl. The validation results look fine; however, the test accuracy is unexpectedly low.

Thank you very much in advance!

Command

  python -m train \
    experiment=hg38/genomic_benchmark \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
    dataset.dest_path="${MY_DATA_PATH_to_dummy_mouse_enhancers_ensembl}" \
    dataset.train_val_split_seed=21 \
    dataset.batch_size=128 \
    dataset.rc_aug=false \
    +dataset.conjoin_train=false \
    +dataset.conjoin_test=false \
    model=caduceus \
    model._name_=dna_embedding_caduceus \
    +model.config_path="${MY_CONFIG_PATH_TO_model_config.json}" \
    +model.conjoin_test=false \
    +decoder.conjoin_train=true \
    +decoder.conjoin_test=false \
    optimizer.lr="6e-4" \
    trainer.max_epochs=100 \
    train.pretrained_model_path="${MY_PRETRAINED_PATH_TO_last.ckpt}" \
    wandb.group="downstream/gb_cv5" \
    wandb.job_type="dummy_mouse_enhancers_ensembl" \
    wandb.name="${WANDB_NAME}" \
    wandb.mode="offline" \
    wandb.id="gb_cv5_dummy_mouse_enhancers_ensembl_${WANDB_NAME}_seed-21" \
    +wandb.tags=\["seed-21"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"

Log

CONFIG
├── train
│   └── seed: 2222                                                                                                                            
│       interval: step                                                                                                                        
│       monitor: val/accuracy                                                                                                                 
│       mode: max                                                                                                                             
│       ema: 0.0                                                                                                                              
│       test: true                                                                                                                            
│       debug: false                                                                                                                          
│       ignore_warnings: false                                                                                                                
│       optimizer_param_grouping:                                                                                                             
│         bias_weight_decay: false                                                                                                            
│         normalization_weight_decay: false                                                                                                   
│       state:                                                                                                                                
│         mode: null                                                                                                                          
│         n_context: 0                                                                                                                        
│         n_context_eval: 0                                                                                                                   
│       ckpt: checkpoints/last.ckpt                                                                                                           
│       disable_dataset: false                                                                                                                
│       validate_at_start: false                                                                                                              
│       pretrained_model_path: /home/zhan8855/scratch/Caduceus/pretrain/hg38/caduceus_ps_seqlen-131k_d_model-256_n_layer-8_lr-8e-3/checkpoints
│       pretrained_model_strict_load: false                                                                                                   
│       pretrained_model_state_hook:                                                                                                          
│         _name_: load_backbone                                                                                                               
│         freeze_backbone: false                                                                                                              
│       post_init_hook:                                                                                                                       
│         _name_: null                                                                                                                        
│       layer_decay:                                                                                                                          
│         _name_: null                                                                                                                        
│         decay: 0.7                                                                                                                          
│       gpu_mem: 41                                                                                                                           
│       global_batch_size: 128                                                                                                                
│       cross_validation: true                                                                                                                
│       remove_test_loader_in_eval: true                                                                                                      
│                                                                                                                                             
├── wandb
│   └── project: dna                                                                                                                          
│       group: downstream/gb_cv5                                                                                                              
│       job_type: dummy_mouse_enhancers_ensembl                                                                                               
│       mode: offline                                                                                                                         
│       name: dummy_mouse_enhancers_ensembl_lr-6e-4_batch_size-128_rc_aug-false                                                               
│       save_dir: .                                                                                                                           
│       id: gb_cv5_dummy_mouse_enhancers_ensembl_dummy_mouse_enhancers_ensembl_lr-6e-4_batch_size-128_rc_aug-false_seed-21                    
│       tags:                                                                                                                                 
│       - seed-21                                                                                                                             
│                                                                                                                                             
├── trainer
│   └── _target_: pytorch_lightning.Trainer                                                                                                   
│       devices: 1                                                                                                                            
│       accelerator: gpu                                                                                                                      
│       accumulate_grad_batches: 1                                                                                                            
│       max_epochs: 100                                                                                                                       
│       gradient_clip_val: 1.0                                                                                                                
│       log_every_n_steps: 10                                                                                                                 
│       limit_train_batches: 1.0                                                                                                              
│       limit_val_batches: 1.0                                                                                                                
│       num_sanity_val_steps: 2                                                                                                               
│       num_nodes: 1                                                                                                                          
│       precision: 16                                                                                                                         
│                                                                                                                                             
├── loader
│   └── num_workers: 4                                                                                                                        
│       pin_memory: true                                                                                                                      
│       drop_last: true                                                                                                                       
│                                                                                                                                             
├── dataset
│   └── _name_: genomic_benchmark                                                                                                             
│       train_val_split_seed: 21                                                                                                              
│       dataset_name: dummy_mouse_enhancers_ensembl                                                                                           
│       dest_path: /home/zhan8855/scratch/caduceus_data/genomic_benchmark                                                                     
│       max_length: 1024                                                                                                                      
│       max_length_val: 1024                                                                                                                  
│       max_length_test: 1024                                                                                                                 
│       d_output: 2                                                                                                                           
│       use_padding: true                                                                                                                     
│       padding_side: left                                                                                                                    
│       add_eos: false                                                                                                                        
│       batch_size: 128                                                                                                                       
│       train_len: 1210                                                                                                                       
│       shuffle: true                                                                                                                         
│       dummy_mouse_enhancers_ensembl:                                                                                                        
│         train_len: 1210                                                                                                                     
│         classes: 2                                                                                                                          
│         max_length: 1024                                                                                                                    
│       demo_coding_vs_intergenomic_seqs:                                                                                                     
│         train_len: 100000                                                                                                                   
│         classes: 2                                                                                                                          
│         max_length: 200                                                                                                                     
│       demo_human_or_worm:                                                                                                                   
│         train_len: 100000                                                                                                                   
│         classes: 2                                                                                                                          
│         max_length: 200                                                                                                                     
│       human_enhancers_cohn:                                                                                                                 
│         train_len: 27791                                                                                                                    
│         classes: 2                                                                                                                          
│         max_length: 500                                                                                                                     
│       human_enhancers_ensembl:                                                                                                              
│         train_len: 154842                                                                                                                   
│         classes: 2                                                                                                                          
│         max_length: 512                                                                                                                     
│       human_ensembl_regulatory:                                                                                                             
│         train_len: 289061                                                                                                                   
│         classes: 3                                                                                                                          
│         max_length: 512                                                                                                                     
│       human_nontata_promoters:                                                                                                              
│         train_len: 36131                                                                                                                    
│         classes: 2                                                                                                                          
│         max_length: 251                                                                                                                     
│       human_ocr_ensembl:                                                                                                                    
│         train_len: 174756                                                                                                                   
│         classes: 2                                                                                                                          
│         max_length: 512                                                                                                                     
│       tokenizer_name: char                                                                                                                  
│       rc_aug: false                                                                                                                         
│       conjoin_train: false                                                                                                                  
│       conjoin_test: false                                                                                                                   
│                                                                                                                                             
├── task
│   └── _name_: multiclass                                                                                                                    
│       loss:                                                                                                                                 
│         _name_: cross_entropy                                                                                                               
│       metrics:                                                                                                                              
│       - accuracy                                                                                                                            
│       torchmetrics: null                                                                                                                    
│                                                                                                                                             
├── optimizer
│   └── _name_: adamw                                                                                                                         
│       lr: 0.0006                                                                                                                            
│       weight_decay: 0.1                                                                                                                     
│       betas:                                                                                                                                
│       - 0.9                                                                                                                                 
│       - 0.999                                                                                                                               
│                                                                                                                                             
├── scheduler
│   └── _name_: cosine_warmup_timm                                                                                                            
│       t_in_epochs: false                                                                                                                    
│       t_initial: 1000                                                                                                                       
│       lr_min: 5.9999999999999995e-05                                                                                                        
│       warmup_lr_init: 1.0e-06                                                                                                               
│       warmup_t: 10.0                                                                                                                        
│                                                                                                                                             
├── callbacks
│   └── learning_rate_monitor:                                                                                                                
│         logging_interval: step                                                                                                              
│       timer:                                                                                                                                
│         step: true                                                                                                                          
│         inter_step: false                                                                                                                   
│         epoch: true                                                                                                                         
│         val: true                                                                                                                           
│       params:                                                                                                                               
│         total: true                                                                                                                         
│         trainable: true                                                                                                                     
│         fixed: true                                                                                                                         
│       model_checkpoint:                                                                                                                     
│         monitor: val/accuracy                                                                                                               
│         mode: max                                                                                                                           
│         save_top_k: 1                                                                                                                       
│         save_last: false                                                                                                                    
│         dirpath: checkpoints/                                                                                                               
│         filename: val/accuracy                                                                                                              
│         auto_insert_metric_name: false                                                                                                      
│         verbose: true                                                                                                                       
│       model_checkpoint_every_n_steps:                                                                                                       
│         monitor: train/loss                                                                                                                 
│         mode: min                                                                                                                           
│         save_top_k: 0                                                                                                                       
│         save_last: true                                                                                                                     
│         dirpath: checkpoints/                                                                                                               
│         filename: train/loss                                                                                                                
│         auto_insert_metric_name: false                                                                                                      
│         verbose: true                                                                                                                       
│         every_n_train_steps: 5000                                                                                                           
│                                                                                                                                             
├── encoder
│   └── id                                                                                                                                    
├── decoder
│   └── _name_: sequence                                                                                                                      
│       mode: pool                                                                                                                            
│       conjoin_train: true                                                                                                                   
│       conjoin_test: false                                                                                                                   
│                                                                                                                                             
└── model
    └── _name_: dna_embedding_caduceus                                                                                                        
        config:                                                                                                                               
          _target_: caduceus.configuration_caduceus.CaduceusConfig                                                                            
          d_model: 128                                                                                                                        
          n_layer: 2                                                                                                                          
          vocab_size: 12                                                                                                                      
          ssm_cfg:                                                                                                                            
            d_state: 16                                                                                                                       
            d_conv: 4                                                                                                                         
            expand: 2                                                                                                                         
            dt_rank: auto                                                                                                                     
            dt_min: 0.001                                                                                                                     
            dt_max: 0.1                                                                                                                       
            dt_init: random                                                                                                                   
            dt_scale: 1.0                                                                                                                     
            dt_init_floor: 0.0001                                                                                                             
            conv_bias: true                                                                                                                   
            bias: false                                                                                                                       
            use_fast_path: true                                                                                                               
          rms_norm: true                                                                                                                      
          fused_add_norm: true                                                                                                                
          residual_in_fp32: false                                                                                                             
          pad_vocab_size_multiple: 8                                                                                                          
          norm_epsilon: 1.0e-05                                                                                                               
          initializer_cfg:                                                                                                                    
            initializer_range: 0.02                                                                                                           
            rescale_prenorm_residual: true                                                                                                    
            n_residuals_per_layer: 1                                                                                                          
          bidirectional: true
          bidirectional_strategy: add                                                                                                         
          bidirectional_weight_tie: true                                                                                                      
          rcps: false                                                                                                                         
          complement_map: null                                                                                                                
        config_path: /home/zhan8855/scratch/Caduceus/pretrain/hg38/caduceus_ps_seqlen-131k_d_model-256_n_layer-8_lr-8e-3/model_config.json    
        conjoin_test: false                                                                                                                   

[rank: 0] Global seed set to 2222
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id gb_cv5_dummy_mouse_enhancers_ensembl_dummy_mouse_enhancers_ensembl_lr-6e-4_batch_size-128_rc_aug-false_seed-21.
wandb: Tracking run with wandb version 0.13.5
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[2024-04-12 20:12:34,147][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.LearningRateMonitor>
[2024-04-12 20:12:34,148][__main__][INFO] - Instantiating callback <src.callbacks.timer.Timer>
[2024-04-12 20:12:34,149][__main__][INFO] - Instantiating callback <src.callbacks.params.ParamsLog>
[2024-04-12 20:12:34,150][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
[2024-04-12 20:12:34,151][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
[2024-04-12 20:12:34,152][__main__][INFO] - Instantiating trainer <pytorch_lightning.Trainer>
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
**Using Char-level tokenizer**
already downloaded train-dummy_mouse_enhancers_ensembl
already downloaded test-dummy_mouse_enhancers_ensembl
**Using Char-level tokenizer**
already downloaded train-dummy_mouse_enhancers_ensembl
already downloaded test-dummy_mouse_enhancers_ensembl
...
key: shape MATCH, loading caduceus model weights
...
[2024-04-12 20:12:38,033][__main__][INFO] - Custom load_state_dict function is running.
/home/zhan8855/scratch/caduceus_env/miniconda3/envs/caduceus/lib/python3.8/site-packages/pytorch_lightning/core/saving.py:242: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['decoder.0.output_transform.weight', 'decoder.0.output_transform.bias']
  rank_zero_warn(
[2024-04-12 20:12:38,089][__main__][INFO] - config.train.ckpt='checkpoints/last.ckpt' fsspec_exists(config.train.ckpt)=False
**Using Char-level tokenizer**
already downloaded train-dummy_mouse_enhancers_ensembl
already downloaded test-dummy_mouse_enhancers_ensembl
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Hyperparameter groups: [{'weight_decay': 0.0}]
[2024-04-12 20:12:38,657][__main__][INFO] - Optimizer group 0 | 75 tensors | weight_decay 0.1
[2024-04-12 20:12:38,657][__main__][INFO] - Optimizer group 1 | 65 tensors | weight_decay 0.0

  | Name               | Type                      | Params
-----------------------------------------------------------------
0 | model              | DNAEmbeddingModelCaduceus | 3.9 M 
1 | encoder            | Identity                  | 0     
2 | decoder            | SequenceDecoder           | 514   
3 | train_torchmetrics | MetricCollection          | 0     
4 | val_torchmetrics   | MetricCollection          | 0     
5 | test_torchmetrics  | MetricCollection          | 0     
-----------------------------------------------------------------

3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
7.731     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
/home/zhan8855/scratch/caduceus_env/miniconda3/envs/caduceus/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1595: PossibleUserWarning: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0: 100%|█| 7/7 [00:07<00:00,  1.04s/it, loss=0.693, v_num=d-21, val/accuracy=0.639, val/loss=0.689, train/accuracy=0.501, train/loss=0.6Epoch 0, global step 6: 'val/accuracy' reached 0.63918 (best 0.63918), saving model to 'checkpoints/val/accuracy.ckpt' as top 1               
Epoch 1: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.69, v_num=d-21, val/accuracy=0.722, val/loss=0.678, train/accuracy=0.641, train/loss=0.68Epoch 1, global step 12: 'val/accuracy' reached 0.72165 (best 0.72165), saving model to 'checkpoints/val/accuracy.ckpt' as top 1              
Epoch 2: 100%|█| 7/7 [00:04<00:00,  1.71it/s, loss=0.686, v_num=d-21, val/accuracy=0.701, val/loss=0.668, train/accuracy=0.654, train/loss=0.6Epoch 2, global step 18: 'val/accuracy' was not in top 1                                                                                      
Epoch 3: 100%|█| 7/7 [00:04<00:00,  1.73it/s, loss=0.679, v_num=d-21, val/accuracy=0.711, val/loss=0.655, train/accuracy=0.661, train/loss=0.6Epoch 3, global step 24: 'val/accuracy' was not in top 1                                                                                      
Epoch 4: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.668, v_num=d-21, val/accuracy=0.732, val/loss=0.641, train/accuracy=0.684, train/loss=0.6Epoch 4, global step 30: 'val/accuracy' reached 0.73196 (best 0.73196), saving model to 'checkpoints/val/accuracy.ckpt' as top 1              
Epoch 5: 100%|█| 7/7 [00:03<00:00,  1.75it/s, loss=0.657, v_num=d-21, val/accuracy=0.742, val/loss=0.622, train/accuracy=0.689, train/loss=0.6Epoch 5, global step 36: 'val/accuracy' reached 0.74227 (best 0.74227), saving model to 'checkpoints/val/accuracy.ckpt' as top 1              
Epoch 6: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.644, v_num=d-21, val/accuracy=0.773, val/loss=0.601, train/accuracy=0.702, train/loss=0.6Epoch 6, global step 42: 'val/accuracy' reached 0.77320 (best 0.77320), saving model to 'checkpoints/val/accuracy.ckpt' as top 1              
Epoch 7: 100%|█| 7/7 [00:03<00:00,  1.77it/s, loss=0.627, v_num=d-21, val/accuracy=0.763, val/loss=0.573, train/accuracy=0.741, train/loss=0.6Epoch 7, global step 48: 'val/accuracy' was not in top 1                                                                                      
Epoch 8: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.601, v_num=d-21, val/accuracy=0.742, val/loss=0.556, train/accuracy=0.783, train/loss=0.5Epoch 8, global step 54: 'val/accuracy' was not in top 1                                                                                      
Epoch 9: 100%|█| 7/7 [00:03<00:00,  1.75it/s, loss=0.563, v_num=d-21, val/accuracy=0.732, val/loss=0.552, train/accuracy=0.842, train/loss=0.5Epoch 9, global step 60: 'val/accuracy' was not in top 1                                                                                      
Epoch 10: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.5, v_num=d-21, val/accuracy=0.701, val/loss=0.564, train/accuracy=0.910, train/loss=0.40Epoch 10, global step 66: 'val/accuracy' was not in top 1                                                                                     
Epoch 11: 100%|█| 7/7 [00:03<00:00,  1.77it/s, loss=0.421, v_num=d-21, val/accuracy=0.732, val/loss=0.535, train/accuracy=0.932, train/loss=0.Epoch 11, global step 72: 'val/accuracy' was not in top 1                                                                                     
Epoch 12: 100%|█| 7/7 [00:04<00:00,  1.75it/s, loss=0.336, v_num=d-21, val/accuracy=0.680, val/loss=0.617, train/accuracy=0.966, train/loss=0.Epoch 12, global step 78: 'val/accuracy' was not in top 1                                                                                     
Epoch 13: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.257, v_num=d-21, val/accuracy=0.732, val/loss=0.558, train/accuracy=0.975, train/loss=0.Epoch 13, global step 84: 'val/accuracy' was not in top 1                                                                                     
Epoch 14: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.195, v_num=d-21, val/accuracy=0.680, val/loss=0.613, train/accuracy=0.984, train/loss=0.Epoch 14, global step 90: 'val/accuracy' was not in top 1                                                                                     
Epoch 15: 100%|█| 7/7 [00:03<00:00,  1.76it/s, loss=0.145, v_num=d-21, val/accuracy=0.711, val/loss=0.615, train/accuracy=0.992, train/loss=0.Epoch 15, global step 96: 'val/accuracy' was not in top 1                                                                                     
Epoch 16: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.113, v_num=d-21, val/accuracy=0.732, val/loss=0.600, train/accuracy=0.991, train/loss=0.Epoch 16, global step 102: 'val/accuracy' was not in top 1                                                                                    
Epoch 17: 100%|█| 7/7 [00:03<00:00,  1.77it/s, loss=0.0883, v_num=d-21, val/accuracy=0.660, val/loss=0.748, train/accuracy=0.993, train/loss=0Epoch 17, global step 108: 'val/accuracy' was not in top 1                                                                                    
Epoch 18: 100%|█| 7/7 [00:04<00:00,  1.75it/s, loss=0.0733, v_num=d-21, val/accuracy=0.670, val/loss=0.771, train/accuracy=0.995, train/loss=0Epoch 18, global step 114: 'val/accuracy' was not in top 1                                                                                    
Epoch 19: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.0629, v_num=d-21, val/accuracy=0.711, val/loss=0.769, train/accuracy=0.995, train/loss=0Epoch 19, global step 120: 'val/accuracy' was not in top 1                                                                                    
Epoch 20: 100%|█| 7/7 [00:03<00:00,  1.76it/s, loss=0.0566, v_num=d-21, val/accuracy=0.680, val/loss=0.912, train/accuracy=0.992, train/loss=0Epoch 20, global step 126: 'val/accuracy' was not in top 1                                                                                    
Epoch 21: 100%|█| 7/7 [00:04<00:00,  1.75it/s, loss=0.0527, v_num=d-21, val/accuracy=0.680, val/loss=0.798, train/accuracy=0.995, train/loss=0Epoch 21, global step 132: 'val/accuracy' was not in top 1                                                                                    
Epoch 22: 100%|█| 7/7 [00:03<00:00,  1.77it/s, loss=0.0545, v_num=d-21, val/accuracy=0.629, val/loss=0.940, train/accuracy=0.992, train/loss=0Epoch 22, global step 138: 'val/accuracy' was not in top 1                                                                                    
Epoch 23: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.0477, v_num=d-21, val/accuracy=0.670, val/loss=0.890, train/accuracy=0.992, train/loss=0Epoch 23, global step 144: 'val/accuracy' was not in top 1                                                                                    
Epoch 24: 100%|█| 7/7 [00:04<00:00,  1.74it/s, loss=0.0459, v_num=d-21, val/accuracy=0.784, val/loss=0.707, train/accuracy=0.996, train/loss=0Epoch 24, global step 150: 'val/accuracy' reached 0.78351 (best 0.78351), saving model to 'checkpoints/val/accuracy.ckpt' as top 1            
..........
Running other epochs
..........
[2024-04-12 20:19:42,262][__main__][INFO] - Loaded best validation checkpoint from epoch 24
**Using Char-level tokenizer**
already downloaded train-dummy_mouse_enhancers_ensembl
already downloaded test-dummy_mouse_enhancers_ensembl
Restoring states from the checkpoint path at checkpoints/val/accuracy.ckpt
[2024-04-12 20:19:42,415][__main__][INFO] - Custom load_state_dict function is running.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at checkpoints/val/accuracy.ckpt
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.96it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric      ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test/accuracy       │    0.27272728085517883    │
│         test/loss         │     2.526049852371216     │
└───────────────────────────┴───────────────────────────┘
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:               epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:       test/accuracy ▁
wandb:           test/loss ▁
wandb:         timer/epoch █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:          timer/step ▂█▂█▂█▂▇▂█▅▇▂▇▂█▁▇▂▇▁▇▂█▃▇▃█▁▇▁▇▁▇▂▇▁▇▁▂
wandb:    timer/validation █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
wandb:      train/accuracy ▁▃▄▄▇███████████████████████████████████
wandb:          train/loss ██▇▇▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:       trainer/epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: trainer/global_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
wandb:        trainer/loss ██▇▇▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:      trainer/lr/pg1 ▇█████████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁
wandb:      trainer/lr/pg2 ▇█████████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁
wandb:        val/accuracy ▁▅▇▇▅▃▅▃▃▁▇▅▇▆▅█▇▆▇▇▇▆▆▆▇▇▇▇▆▆▅▆▅▅▅▅▅▅▅▅
wandb:            val/loss ▂▂▁▁▁▁▁▃▄▄▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▆▆▆▇▇▇█████
wandb: 
wandb: Run summary:
wandb:               epoch 24
wandb:       test/accuracy 0.27273
wandb:           test/loss 2.52605
wandb:         timer/epoch 4.08477
wandb:          timer/step 0.55261
wandb:    timer/validation 0.48716
wandb:      train/accuracy 1.0
wandb:          train/loss 0.00025
wandb:       trainer/epoch 99.0
wandb: trainer/global_step 600
wandb:        trainer/loss 0.00024
wandb:      trainer/lr/pg1 0.00025
wandb:      trainer/lr/pg2 0.00025
wandb:        val/accuracy 0.71134
wandb:            val/loss 1.38916
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync ./wandb/offline-run-20240412_201231-gb_cv5_dummy_mouse_enhancers_ensembl_dummy_mouse_enhancers_ensembl_lr-6e-4_batch_size-128_rc_aug-false_seed-21
wandb: Find logs at: ./wandb/offline-run-20240412_201231-gb_cv5_dummy_mouse_enhancers_ensembl_dummy_mouse_enhancers_ensembl_lr-6e-4_batch_size-128_rc_aug-false_seed-21/logs
zhan8855 commented 2 months ago

I fixed this by replacing

https://github.com/kuleshov-group/caduceus/blob/main/src/dataloaders/datasets/genomic_bench_dataset.py#L70

with

for i, x in enumerate(sorted(base_path.iterdir())):

yair-schiff commented 2 months ago

To be honest, I am not sure I follow what issue your fix is resolving.

In any case, you can also try playing with the LR and batch size. I found these tasks to be quite finicky and sometimes sensitive to those parameters, with some runs leading to NaN values, which could explain the low test accuracy.
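For example, both knobs are plain Hydra overrides in the command above, so trying different settings only requires changing their values (the numbers below are illustrative, not recommendations; keep the remaining overrides from the original command):

  python -m train \
    experiment=hg38/genomic_benchmark \
    dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
    dataset.batch_size=64 \
    optimizer.lr="1e-3"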

zhan8855 commented 2 months ago

The issue is that the Genomic Benchmarks datasets have sub-directories "positive" and "negative" under both the "train" and "test" splits, and base_path.iterdir() can return ["negative", "positive"] when loading the "train" split while returning ["positive", "negative"] when loading the "test" split. In that case, the test labels are not aligned with the train and validation labels.
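For illustration, here is a minimal sketch of the failure mode and of the fix above (the mapping logic mirrors the loop in genomic_bench_dataset.py, but the helper names here are hypothetical):

  from pathlib import Path

  # Each split directory holds one sub-directory per class, e.g.
  # train/positive, train/negative, test/positive, test/negative.

  def build_label_map(base_path: Path) -> dict:
      # BUG: Path.iterdir() yields entries in an arbitrary,
      # filesystem-dependent order, so "positive" can map to 0
      # for one split and to 1 for the other.
      return {x.stem: i for i, x in enumerate(base_path.iterdir())}

  def build_label_map_fixed(base_path: Path) -> dict:
      # Fix: sorting the entries makes the class-index assignment
      # deterministic and therefore identical across splits.
      return {x.stem: i for i, x in enumerate(sorted(base_path.iterdir()))}

Any stable ordering would do; what matters is that the train and test splits derive the same class-to-index mapping.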

yair-schiff commented 2 months ago

I see. Nice find. I wonder why I didn't hit this issue.