IdoAmos / not-from-scratch

MIT License

Questions about reproducing the results of Transformer model on Pathx task #1

Closed Cookize closed 6 months ago

Cookize commented 7 months ago

Hi, we tried to reproduce the results of the Transformer model on PathX as reported in the paper (Transformer + Causal SPT), but could not achieve the corresponding numbers.

The result reported in the paper is 88.05.

[image: results table from the paper]

In our experiments, the Transformer model only achieves random-guessing accuracy on PathX, whether or not SPT is performed. We wondered whether something was wrong with the way we launched the experiments.


Here are some details.

First, we started SPT with the following configuration:

CONFIG
├── train
│   └── seed: 2222                                                              
│       interval: step                                                          
│       monitor: val/accuracy_ignore_m100                                       
│       mode: max                                                               
│       ema: 0.0                                                                
│       test: false                                                             
│       debug: false                                                            
│       ignore_warnings: false                                                  
│       state:                                                                  
│         mode: null                                                            
│         n_context: 0                                                          
│         n_context_eval: 0                                                     
│       ckpt: null                                                              
│       disable_dataset: false                                                  
│       validate_at_start: false                                                
│       pretrained_model_path: null                                             
│       pretrained_model_strict_load: true                                      
│       pretrained_model_state_hook:                                            
│         _name_: null                                                          
│       post_init_hook:                                                         
│         _name_: null                                                          
│       layer_decay:                                                            
│         _name_: null                                                          
│         decay: 0.7                                                            
│       manual_checkpoints: false                                               
│       ckpt_milestones: null                                                   
│       log_tensors_freq: -1                                                    
│                                                                               
├── tolerance
│   └── logdir: ./resume                                                        
│       id: null                                                                
│                                                                               
├── wandb
│   └── project: state-spaces                                                   
│       group: ''                                                               
│       job_type: training                                                      
│       mode: online                                                            
│       save_dir: null                                                          
│       id: null                                                                
│       entity: ''                                                              
│       name: ''                                                                
│                                                                               
├── trainer
│   └── gpus: 8                                                                 
│       accumulate_grad_batches: 1                                              
│       max_epochs: 200                                                         
│       gradient_clip_val: 0.0                                                  
│       log_every_n_steps: 10                                                   
│       limit_train_batches: 1.0                                                
│       limit_val_batches: 1.0                                                  
│       weights_summary: top                                                    
│       progress_bar_refresh_rate: 1                                            
│       track_grad_norm: -1                                                     
│       val_check_interval: 0.5                                                 
│                                                                               
├── optimizer
│   └── _name_: adamw                                                           
│       lr: 0.0005                                                              
│       weight_decay: 0.0                                                       
│       betas:                                                                  
│       - 0.9                                                                   
│       - 0.999                                                                 
│                                                                               
├── scheduler
│   └── _name_: constant_warmup                                                 
│       num_warmup_steps: 5000                                                  
│                                                                               
├── loader
│   └── batch_size: 4                                                           
│       num_workers: 4                                                          
│       pin_memory: true                                                        
│       drop_last: true                                                         
│                                                                               
├── dataset
│   └── _name_: pathfinder_lm                                                   
│       resolution: 128                                                         
│       sequential: true                                                        
│       val_split: 0.1                                                          
│       test_split: 0.1                                                         
│       mlm_prob: 0.0                                                           
│       ignore_val: true                                                        
│       causal_lm: true                                                         
│       lm_loss: ce                                                             
│       span_masking: false                                                     
│       span_length: 1                                                          
│       seed: 42                                                                
│                                                                               
├── task
│   └── _name_: base                                                            
│       loss: cross_entropy                                                     
│       metrics:                                                                
│       - accuracy_ignore_m100                                                  
│       torchmetrics: null                                                      
│                                                                               
├── encoder
│   └── linear                                                                  
├── decoder
│   └── _name_: sequence                                                        
│       mode: last                                                              
│                                                                               
├── model
│   └── layer:                                                                  
│       - _name_: mhfla                                                         
│         causal: false                                                         
│         window_size: 4096                                                     
│         look_forward: 1                                                       
│         look_backward: 1                                                      
│         n_heads: 4                                                            
│         dropout: null                                                         
│         bias: true                                                            
│         add_bias_kv: false                                                    
│         add_zero_attn: false                                                  
│         kdim: null                                                            
│         vdim: null                                                            
│         rotary: false                                                         
│       - _name_: ff                                                            
│         expand: 1                                                             
│         activation: gelu                                                      
│         dropout: 0.0                                                          
│       _name_: model                                                           
│       prenorm: true                                                           
│       transposed: false                                                       
│       n_layers: 4                                                             
│       d_model: 128                                                            
│       residual: R                                                             
│       pool:                                                                   
│         _name_: pool                                                          
│         stride: 1                                                             
│         expand: null                                                          
│       norm: layer                                                             
│       dropout: 0.0                                                            
│       tie_dropout: false                                                      
│       track_norms: true                                                       
│       encoder:                                                                
│         _name_: position                                                      
│         dropout: 0.0                                                          
│       decoder: null                                                           
│                                                                               
└── callbacks
    └── learning_rate_monitor:                                                  
          logging_interval: step                                                
        timer:                                                                  
          step: true                                                            
          inter_step: false                                                     
          epoch: true                                                           
          val: true                                                             
        params:                                                                 
          total: true                                                           
          trainable: true                                                       
          fixed: true                                                           
        model_checkpoint:                                                       
          monitor: val/accuracy_ignore_m100                                     
          mode: max                                                             
          save_top_k: 1                                                         
          save_last: true                                                       
          dirpath: checkpoints/                                                 
          filename: val/accuracy_ignore_m100                                    
          auto_insert_metric_name: false                                        
          verbose: true    

The curves during training are as follows:

[images: train-loss curve and validation-set performance curve]

After training for 6 epochs, next-token-prediction accuracy reached 99%, so we considered the model converged.

Then, we fine-tuned the pretrained model above on PathX. The configuration is as follows:

CONFIG
├── train
│   └── seed: 2222                                                              
│       interval: step                                                          
│       monitor: val/accuracy                                                   
│       mode: max                                                               
│       ema: 0.0                                                                
│       test: false                                                             
│       debug: false                                                            
│       ignore_warnings: false                                                  
│       state:                                                                  
│         mode: null                                                            
│         n_context: 0                                                          
│         n_context_eval: 0                                                     
│       ckpt: null                                                              
│       disable_dataset: false                                                  
│       validate_at_start: false                                                
│       pretrained_model_path: /my/path/to/pretrained_model.ckpt 
│       pretrained_model_strict_load: true                                      
│       pretrained_model_state_hook:                                            
│         _name_: null                                                          
│       post_init_hook:                                                         
│         _name_: null                                                          
│       layer_decay:                                                            
│         _name_: null                                                          
│         decay: 0.7                                                            
│       manual_checkpoints: true                                                
│       ckpt_milestones: null                                                   
│       log_tensors_freq: -1                                                    
│                                                                               
├── tolerance
│   └── logdir: ./resume                                                        
│       id: null                                                                
│                                                                               
├── wandb
│   └── project: state-spaces                                                   
│       group: ''                                                               
│       job_type: training                                                      
│       mode: online                                                            
│       save_dir: null                                                          
│       id: null                                                                
│       entity: ''                                                              
│       name: ''                                                                
│                                                                               
├── trainer
│   └── gpus: 8                                                                 
│       accumulate_grad_batches: 1                                              
│       max_epochs: 200                                                         
│       gradient_clip_val: 0.0                                                  
│       log_every_n_steps: 10                                                   
│       limit_train_batches: 1.0                                                
│       limit_val_batches: 1.0                                                  
│       weights_summary: top                                                    
│       progress_bar_refresh_rate: 1                                            
│       track_grad_norm: -1                                                     
│                                                                               
├── loader
│   └── batch_size: 4                                                           
│       num_workers: 4                                                          
│       pin_memory: true                                                        
│       drop_last: true                                                         
│                                                                               
├── dataset
│   └── _name_: pathfinder                                                      
│       resolution: 128                                                         
│       sequential: true                                                        
│       tokenize: false                                                         
│       pool: 1                                                                 
│       val_split: 0.1                                                          
│       test_split: 0.1                                                         
│       seed: 42                                                                
│                                                                               
├── task
│   └── _name_: base                                                            
│       loss: cross_entropy                                                     
│       metrics:                                                                
│       - accuracy                                                              
│       torchmetrics: null                                                      
│                                                                               
├── optimizer
│   └── _name_: adamw                                                           
│       lr: 0.0005                                                              
│       weight_decay: 0.0                                                       
│       betas:                                                                  
│       - 0.9                                                                   
│       - 0.999                                                                 
│                                                                               
├── scheduler
│   └── _name_: constant_warmup                                                 
│       num_warmup_steps: 5000                                                  
│                                                                               
├── encoder
│   └── linear                                                                  
├── decoder
│   └── _name_: sequence                                                        
│       mode: max                                                               
│                                                                               
├── model
│   └── layer:                                                                  
│       - _name_: mhfla                                                         
│         causal: false                                                         
│         window_size: 4096                                                     
│         look_forward: 1                                                       
│         look_backward: 1                                                      
│         n_heads: 4                                                            
│         dropout: null                                                         
│         bias: true                                                            
│         add_bias_kv: false                                                    
│         add_zero_attn: false                                                  
│         kdim: null                                                            
│         vdim: null                                                            
│         rotary: false                                                         
│       - _name_: ff                                                            
│         expand: 1                                                             
│         activation: gelu                                                      
│         dropout: 0.0                                                          
│       _name_: model                                                           
│       prenorm: true                                                           
│       transposed: false                                                       
│       n_layers: 4                                                             
│       d_model: 128                                                            
│       residual: R                                                             
│       pool:                                                                   
│         _name_: pool                                                          
│         stride: 1                                                             
│         expand: null                                                          
│       norm: layer                                                             
│       dropout: 0.0                                                            
│       tie_dropout: false                                                      
│       track_norms: true                                                       
│       encoder:                                                                
│         _name_: position                                                      
│         dropout: 0.0                                                          
│       decoder: null                                                           
│                                                                               
└── callbacks
    └── learning_rate_monitor:                                                  
          logging_interval: step                                                
        timer:                                                                  
          step: true                                                            
          inter_step: false                                                     
          epoch: true                                                           
          val: true                                                             
        params:                                                                 
          total: true                                                           
          trainable: true                                                       
          fixed: true                                                           
        model_checkpoint:                                                       
          monitor: val/accuracy                                                 
          mode: max                                                             
          save_top_k: 1                                                         
          save_last: true                                                       
          dirpath: checkpoints/                                                 
          filename: val/accuracy                                                
          auto_insert_metric_name: false                                        
          verbose: true                                                      

The curves during training are as follows: [images: train-loss curve and validation-set performance curve]

In the FT phase, the loss did not decrease and the model remained at random guessing.


We would like to know whether these training curves look correct. Could you provide the detailed hyperparameters that reproduce this experiment from the paper, including the number of training epochs in the SPT and FT phases?

IdoAmos commented 7 months ago

Hi! Very glad to hear you are using our repo and experimenting with SPT. With SPT, you should see an increase in performance very early in training if your hyperparameters are tuned; you can check out the ones we used in the paper, Appendix C.1.

For the issue mentioned, it's hard to tell from the config file alone; it would help if you could share the run command as well.

If I understand correctly, you are pretraining with a next-token-prediction objective. If so, note that your config allows the model to attend to both the backward and forward windows (i.e., future tokens), which we didn't try; I'm guessing this leads to less effective pretraining. For PathX, the default config uses chunked attention due to the long sequence length. For causal training you need to specify the mode and the attended windows explicitly: `model.layer.causal=true model.layer.look_forward=0 model.layer.look_backward=<num_windows>`. In the experiments listed in the paper I always used `<num_windows>=1`.
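To make the windowing concrete, here is a rough sketch of how a chunked local-attention mask behaves under these settings. This is an illustration written for this thread, not the repo's actual `mhfla` implementation; the function name and shapes are my own. It shows why `look_forward=1` with `causal=false` lets each position see its own next-token target, while `causal=true, look_forward=0` does not.

```python
import numpy as np

def chunked_attention_mask(seq_len, window, look_backward, look_forward, causal):
    """Boolean mask where mask[q, k] = True means query q may attend key k.

    Queries in window i attend keys in windows [i - look_backward, i + look_forward];
    with causal=True, strictly future positions are additionally masked out.
    Simplified illustration only, not the repo's `mhfla` layer.
    """
    q_win = np.arange(seq_len) // window           # window index of each query
    k_win = np.arange(seq_len) // window           # window index of each key
    diff = k_win[None, :] - q_win[:, None]         # key window relative to query window
    mask = (diff >= -look_backward) & (diff <= look_forward)
    if causal:
        # within the allowed windows, only attend positions <= the query position
        mask &= np.arange(seq_len)[None, :] <= np.arange(seq_len)[:, None]
    return mask

# Bidirectional setting from the posted config: the next token is visible,
# so next-token prediction becomes near-trivial during SPT.
m_bi = chunked_attention_mask(8, window=2, look_backward=1, look_forward=1, causal=False)

# Causal SPT setting from the reply: no query can see any future position.
m_ca = chunked_attention_mask(8, window=2, look_backward=1, look_forward=0, causal=True)
```

With the bidirectional mask, `m_bi[0, 1]` is `True` (position 0 attends its own prediction target), whereas the causal mask has no `True` entry above the diagonal.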

I hope this solves the issue; please let me know otherwise.