princeton-nlp / MeZO

[NeurIPS 2023] MeZO: Fine-Tuning Language Models with Just Forward Passes. https://arxiv.org/abs/2305.17333
MIT License
1.04k stars 63 forks source link

MeZo can be used in NLG tasks? #4

Open anonNo2 opened 1 year ago

anonNo2 commented 1 year ago

Can MeZo be used on NLG tasks? I integrated the _inner_training_loop part of the code and the methods it relies on into the NLG task training code, and performed fine-tuning training on bloom (bloomz-1b), and found that the effect was relatively poor. could you provide some guidance ?

gaotianyu1350 commented 1 year ago

Hi,

MeZO works on NLG tasks (see SQuAD as an example). We'll provide a more detailed readme very soon!

anonNo2 commented 1 year ago

Thank you very much for your reply !

I tried to use Bloomz-1b to run SQuAD2 on your open source code, the following is my hyperparameter.

task_name: str = "SQuADv2" 
num_train: int = 10000 
num_dev: int = None 
num_eval: int = None
num_train_sets: int = 1  
train_set_seed: int = None  
result_file: str = None 

model_name: str = "bigscience/bloomz-1b1"
load_float16: bool = True 
load_bfloat16: bool = False
load_int8: bool = False

max_length: int = 2048
no_auto_device: bool = False
# calibration
sfc: bool = False
icl_sfc: bool = False

# training
trainer: str = "zo"
## options
## - none: no training. only use in-context learning or zero-shot.
## - regular: regular trainer
## - zo: zeroth-order training
only_train_option: bool = False
train_as_classification: bool = False

zo_inplace: bool = False  
zo_eps: float = 1e-3  
zo_torch_optim: bool = False 
zo_sample_scheduler: str = None  
zo_sample: int = 1 
zo_clip_grad: float = None  
zo_scale_lr_with_samples: bool = None
zo_pc: bool = False 
zo_pc_recompute: bool = False  
zo_pc_split_by_emb: bool = False  
zo_pc_w_zo_estimate: bool = False 
zo_pc_use_norm: bool = False 
zo_pc_scale_by_num_params: bool = False 
zo_pc_rnd_layers: bool = False  
zo_layer_wise_optim: bool = False  

# prefix tuning
prefix_tuning: bool = False 
num_prefix: int = 5 
no_reparam: bool = False
prefix_init_by_real_act: bool = False

# lora
lora: bool = False 
lora_alpha: int = 16
lora_r: int = 8

# generation
sampling: bool = False
temperature: float = 1.0 
num_beams: int = 1 
top_k: int = None 
top_p: float = 0.95 
max_new_tokens: int = 50  
eos_token: str = "\n" 

# saving
save_model: bool = False
no_eval: bool = False
tag: str = ""

# linear probing
linear_probing: bool = False
lp_early_stopping: bool = False

# head-tuning
head_tuning: bool = False

# untie emb/lm_head weights
untie_emb: bool = False

# display
verbose: bool = False

# non-diff objective
non_diff: bool = False

At the same time, I set --logging_steps to 50 , and I found a strange phenomenon that Loss will remain at 0 after experiencing the initial ups and downs. I don't know if there is a problem with my hyperparameter settings. I haven't modified the code logic. The following is my log print

{'loss': 74.8421, 'learning_rate': 4.933333333333334e-05, 'epoch': 0.04}                                              
{'loss': 240.675, 'learning_rate': 4.866666666666667e-05, 'epoch': 0.08}                                              
{'loss': 639.575, 'learning_rate': 4.8e-05, 'epoch': 0.12}                                                            
{'loss': 967.742, 'learning_rate': 4.7333333333333336e-05, 'epoch': 0.16}                                             
{'loss': 0.0, 'learning_rate': 4.666666666666667e-05, 'epoch': 0.2}                                                   
{'loss': 0.0, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.24}                                                  
{'loss': 0.0, 'learning_rate': 4.5333333333333335e-05, 'epoch': 0.28}                                                 
{'loss': 0.0, 'learning_rate': 4.466666666666667e-05, 'epoch': 0.32}                                                  
{'loss': 0.0, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.36}                                                 
{'loss': 0.0, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.4}                                                  
{'loss': 0.0, 'learning_rate': 4.266666666666667e-05, 'epoch': 0.44}                                                  
{'loss': 0.0, 'learning_rate': 4.2e-05, 'epoch': 0.48}                                                                
{'loss': 0.0, 'learning_rate': 4.133333333333333e-05, 'epoch': 0.52}                                                  
{'loss': 0.0, 'learning_rate': 4.066666666666667e-05, 'epoch': 0.56}                                                  
{'loss': 0.0, 'learning_rate': 4e-05, 'epoch': 0.6}                                                                   
{'loss': 0.0, 'learning_rate': 3.933333333333333e-05, 'epoch': 0.64}                                                  
{'loss': 0.0, 'learning_rate': 3.866666666666667e-05, 'epoch': 0.68}                                                  
{'loss': 0.0, 'learning_rate': 3.8e-05, 'epoch': 0.72}                                                                
{'loss': 0.0, 'learning_rate': 3.733333333333334e-05, 'epoch': 0.76}                                                  
{'loss': 0.0, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.8}                                                  
{'loss': 0.0, 'learning_rate': 3.6e-05, 'epoch': 0.84}                                                                
{'loss': 0.0, 'learning_rate': 3.5333333333333336e-05, 'epoch': 0.88}                                                 
{'loss': 0.0, 'learning_rate': 3.466666666666667e-05, 'epoch': 0.92}                                                  
{'loss': 0.0, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.96}                                                 
{'loss': 0.0, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}                                                  
{'loss': 0.0, 'learning_rate': 3.266666666666667e-05, 'epoch': 1.04}                                                  
{'loss': 0.0, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.08}                                                 
{'loss': 0.0, 'learning_rate': 3.1333333333333334e-05, 'epoch': 1.12}                                                 
{'loss': 0.0, 'learning_rate': 3.066666666666667e-05, 'epoch': 1.16}                                                  
{'loss': 0.0, 'learning_rate': 3e-05, 'epoch': 1.2}                                                                   
{'loss': 0.0, 'learning_rate': 2.9333333333333336e-05, 'epoch': 1.24}                                                 
{'loss': 0.0, 'learning_rate': 2.8666666666666668e-05, 'epoch': 1.28}                                                 
{'loss': 0.0, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.32}                                                 
{'loss': 0.0, 'learning_rate': 2.733333333333333e-05, 'epoch': 1.36}                                                  
{'loss': 0.0, 'learning_rate': 2.6666666666666667e-05, 'epoch': 1.4}                                                  
{'loss': 0.0, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.44}                                                 
{'loss': 0.0, 'learning_rate': 2.5333333333333337e-05, 'epoch': 1.48}                                                 
{'loss': 0.0, 'learning_rate': 2.466666666666667e-05, 'epoch': 1.52}                                                  
{'loss': 0.0, 'learning_rate': 2.4e-05, 'epoch': 1.56}                                                                
{'loss': 0.0, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.6}                                                  
{'loss': 0.0, 'learning_rate': 2.2666666666666668e-05, 'epoch': 1.64}                                                 
{'loss': 0.0, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.68}                                                 
{'loss': 0.0, 'learning_rate': 2.1333333333333335e-05, 'epoch': 1.72}                                                 
{'loss': 0.0, 'learning_rate': 2.0666666666666666e-05, 'epoch': 1.76}                                                 
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 1.8}                                                                   
{'loss': 0.0, 'learning_rate': 1.9333333333333333e-05, 'epoch': 1.84}                                                 
{'loss': 0.0, 'learning_rate': 1.866666666666667e-05, 'epoch': 1.88}                                                  
{'loss': 0.0, 'learning_rate': 1.8e-05, 'epoch': 1.92}                                                                
{'loss': 0.0, 'learning_rate': 1.7333333333333336e-05, 'epoch': 1.96}                                                 
{'loss': 0.0, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}                                                  
{'loss': 0.0, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.04}                                                 
{'loss': 0.0, 'learning_rate': 1.5333333333333334e-05, 'epoch': 2.08}                                                 
{'loss': 0.0, 'learning_rate': 1.4666666666666668e-05, 'epoch': 2.12}                                                 
{'loss': 0.0, 'learning_rate': 1.4000000000000001e-05, 'epoch': 2.16}                                                 
{'loss': 0.0, 'learning_rate': 1.3333333333333333e-05, 'epoch': 2.2}                                                  
{'loss': 0.0, 'learning_rate': 1.2666666666666668e-05, 'epoch': 2.24}                                                 
{'loss': 0.0, 'learning_rate': 1.2e-05, 'epoch': 2.28}                                                                
{'loss': 0.0, 'learning_rate': 1.1333333333333334e-05, 'epoch': 2.32}                                                 
{'loss': 0.0, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.36}                                                 
{'loss': 0.0, 'learning_rate': 1e-05, 'epoch': 2.4}                                                                   
{'loss': 0.0, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.44}                                                  
{'loss': 0.0, 'learning_rate': 8.666666666666668e-06, 'epoch': 2.48}                                                  
{'loss': 0.0, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.52}                                                  
{'loss': 0.0, 'learning_rate': 7.333333333333334e-06, 'epoch': 2.56}                                                  
{'loss': 0.0, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.6}                                                   
{'loss': 0.0, 'learning_rate': 6e-06, 'epoch': 2.64}                                                                  
{'loss': 0.0, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.68}                                                  
{'loss': 0.0, 'learning_rate': 4.666666666666667e-06, 'epoch': 2.72}                                                  
{'loss': 0.0, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.76}                                                  
{'loss': 0.0, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.8}                                                  
{'loss': 0.0, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.84}                                                  
{'loss': 0.0, 'learning_rate': 2.0000000000000003e-06, 'epoch': 2.88}                                                 
{'loss': 0.0, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.92}                                                 
{'loss': 0.0, 'learning_rate': 6.666666666666667e-07, 'epoch': 2.96}                                                  
{'loss': 0.0, 'learning_rate': 0.0, 'epoch': 3.0}  

Could you give me some guidance, I really really think MeZo is cool 🤔

anonNo2 commented 1 year ago

But I didn't change any hyperparameters, just changed the model used to facebook/opt-125m, it seems that the loss can drop smoothly. 😂

{'loss': 3.7535, 'learning_rate': 4.933333333333334e-05, 'epoch': 0.04}                                               
{'loss': 7.4188, 'learning_rate': 4.866666666666667e-05, 'epoch': 0.08}                                               
{'loss': 8.3094, 'learning_rate': 4.8e-05, 'epoch': 0.12}                                                             
{'loss': 8.3594, 'learning_rate': 4.7333333333333336e-05, 'epoch': 0.16}                                              
{'loss': 8.2979, 'learning_rate': 4.666666666666667e-05, 'epoch': 0.2}                                                
{'loss': 8.2892, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.24}                                               
{'loss': 8.2078, 'learning_rate': 4.5333333333333335e-05, 'epoch': 0.28}                                              
{'loss': 8.2466, 'learning_rate': 4.466666666666667e-05, 'epoch': 0.32}                                               
{'loss': 8.206, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.36}                                               
{'loss': 8.1586, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.4}                                               
{'loss': 8.1869, 'learning_rate': 4.266666666666667e-05, 'epoch': 0.44}                                               
{'loss': 8.1488, 'learning_rate': 4.2e-05, 'epoch': 0.48}                                                             
{'loss': 8.1909, 'learning_rate': 4.133333333333333e-05, 'epoch': 0.52}                                               
{'loss': 8.1738, 'learning_rate': 4.066666666666667e-05, 'epoch': 0.56}                                               
{'loss': 8.1662, 'learning_rate': 4e-05, 'epoch': 0.6}                                                                
{'loss': 8.1148, 'learning_rate': 3.933333333333333e-05, 'epoch': 0.64}                                               
{'loss': 8.1228, 'learning_rate': 3.866666666666667e-05, 'epoch': 0.68}                                               
{'loss': 8.1038, 'learning_rate': 3.8e-05, 'epoch': 0.72}                                                             
{'loss': 8.1227, 'learning_rate': 3.733333333333334e-05, 'epoch': 0.76}                                               
{'loss': 8.1639, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.8}                                               
{'loss': 8.12, 'learning_rate': 3.6e-05, 'epoch': 0.84}                                                               
{'loss': 8.1024, 'learning_rate': 3.5333333333333336e-05, 'epoch': 0.88}                                              
{'loss': 8.1228, 'learning_rate': 3.466666666666667e-05, 'epoch': 0.92}                                               
{'loss': 8.1113, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.96}                                              
{'loss': 8.0706, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}                                               
{'loss': 8.0874, 'learning_rate': 3.266666666666667e-05, 'epoch': 1.04}                                               
{'loss': 8.0905, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.08}                                              
{'loss': 8.0841, 'learning_rate': 3.1333333333333334e-05, 'epoch': 1.12}                                              
{'loss': 8.1071, 'learning_rate': 3.066666666666667e-05, 'epoch': 1.16}                                               
{'loss': 8.058, 'learning_rate': 3e-05, 'epoch': 1.2}                                                                 
{'loss': 8.0488, 'learning_rate': 2.9333333333333336e-05, 'epoch': 1.24}                                              
{'loss': 8.0453, 'learning_rate': 2.8666666666666668e-05, 'epoch': 1.28}                                              
{'loss': 8.047, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.32}                                               
{'loss': 8.0246, 'learning_rate': 2.733333333333333e-05, 'epoch': 1.36}                                               
{'loss': 7.9912, 'learning_rate': 2.6666666666666667e-05, 'epoch': 1.4}                                               
{'loss': 8.0264, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.44}                                              
{'loss': 7.982, 'learning_rate': 2.5333333333333337e-05, 'epoch': 1.48}                                               
{'loss': 7.9686, 'learning_rate': 2.466666666666667e-05, 'epoch': 1.52}                                               
{'loss': 7.9592, 'learning_rate': 2.4e-05, 'epoch': 1.56}                                                             
{'loss': 7.9862, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.6}                                               
{'loss': 7.9894, 'learning_rate': 2.2666666666666668e-05, 'epoch': 1.64}                                              
{'loss': 7.9723, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.68}                                              
{'loss': 7.9499, 'learning_rate': 2.1333333333333335e-05, 'epoch': 1.72}                                              
{'loss': 7.9565, 'learning_rate': 2.0666666666666666e-05, 'epoch': 1.76}                                              
{'loss': 7.938, 'learning_rate': 2e-05, 'epoch': 1.8}                                                                 
{'loss': 7.9411, 'learning_rate': 1.9333333333333333e-05, 'epoch': 1.84}                                              
{'loss': 7.9223, 'learning_rate': 1.866666666666667e-05, 'epoch': 1.88}                                               
{'loss': 7.9638, 'learning_rate': 1.8e-05, 'epoch': 1.92}                                                             
{'loss': 7.9739, 'learning_rate': 1.7333333333333336e-05, 'epoch': 1.96}                                              
{'loss': 7.9316, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}                                               
{'loss': 7.9279, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.04}                                              
{'loss': 7.9124, 'learning_rate': 1.5333333333333334e-05, 'epoch': 2.08}                                              
{'loss': 7.91, 'learning_rate': 1.4666666666666668e-05, 'epoch': 2.12}                                                
{'loss': 7.908, 'learning_rate': 1.4000000000000001e-05, 'epoch': 2.16}                                               
{'loss': 7.9101, 'learning_rate': 1.3333333333333333e-05, 'epoch': 2.2}                                               
{'loss': 7.9208, 'learning_rate': 1.2666666666666668e-05, 'epoch': 2.24}                                              
{'loss': 7.9161, 'learning_rate': 1.2e-05, 'epoch': 2.28}                                                             
{'loss': 7.9043, 'learning_rate': 1.1333333333333334e-05, 'epoch': 2.32}                                              
{'loss': 7.9158, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.36}                                              
{'loss': 7.8968, 'learning_rate': 1e-05, 'epoch': 2.4}                                                                
{'loss': 7.8973, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.44}                                               
{'loss': 7.8825, 'learning_rate': 8.666666666666668e-06, 'epoch': 2.48}                                               
{'loss': 7.8814, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.52}                                               
{'loss': 7.8905, 'learning_rate': 7.333333333333334e-06, 'epoch': 2.56}                                               
{'loss': 7.8777, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.6}                                                
{'loss': 7.8753, 'learning_rate': 6e-06, 'epoch': 2.64}                                                               
{'loss': 7.8527, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.68}                                               
{'loss': 7.8781, 'learning_rate': 4.666666666666667e-06, 'epoch': 2.72}                                               
{'loss': 7.8937, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.76}                                               
{'loss': 7.8705, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.8}                                               
{'loss': 7.8792, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.84}                                               
{'loss': 7.8713, 'learning_rate': 2.0000000000000003e-06, 'epoch': 2.88}                                              
{'loss': 7.8866, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.92}                                              
{'loss': 7.8813, 'learning_rate': 6.666666666666667e-07, 'epoch': 2.96}                                               
{'loss': 7.8713, 'learning_rate': 0.0, 'epoch': 3.0}  
gaotianyu1350 commented 1 year ago

Hi,

Thanks for your interest in our paper! We just updated the README and there is an example for non-differentiable training. Maybe try out the new example code?

lramming commented 1 year ago

At the same time, I set --logging_steps to 50 , and I found a strange phenomenon that Loss will remain at 0 after experiencing the initial ups and downs. I don't know if there is a problem with my hyperparameter settings. I haven't modified the code logic. The following is my log print

I think a loss of 0 is an indicator for an unreported over/underflow that leads to the loss not being computed correctly. I would advise to reduce the learning rate. Intuitively reason is that the 1B model has ~10x the number of parameters of the opt-125m model, so the loss landscape is much more complex. If you keep the lr high, then it can not navigate through the loss landscape to find the minimum, so you need lower learning rates with larger models.