princeton-nlp / MeZO

[NeurIPS 2023] MeZO: Fine-Tuning Language Models with Just Forward Passes. https://arxiv.org/abs/2305.17333
MIT License

Can you share the dataset class of SST-5, SNLI, TREC datasets? #36

Ziiiirem opened this issue 4 weeks ago

Ziiiirem commented 4 weeks ago

Hi, I am interested in your non-differentiable objective experiments with MeZO, but I can't find the dataset classes and prompt templates for the SST-5, SNLI, and TREC datasets. Could you share them? Thank you very much!

Ziiiirem commented 4 weeks ago

Also, I tried to modify the code to support zeroth-order training directly on accuracy, a non-differentiable objective. I used the roberta-large model on the SST-2 dataset, with a batch size of 512 and learning rates of 1e-6 and 5e-7. I tried to reproduce the results in your paper, but the training results were poor. Could you share this part of your code implementation?
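(For context: the zeroth-order approach being discussed is SPSA-style two-point estimation, which only requires the objective to be *evaluable*, not differentiable, so a metric like accuracy works. A minimal toy sketch of one such step, making no assumptions about the repo's actual code; `spsa_step` and the hyperparameter values are illustrative, not the paper's settings:)

```python
import numpy as np

def spsa_step(params, objective, lr=1e-3, eps=1e-3, rng=None):
    """One SPSA-style zeroth-order step, maximizing `objective`.

    `objective` only needs to return a scalar (e.g. accuracy);
    no gradients are ever computed, just two forward evaluations.
    """
    rng = rng or np.random.default_rng()
    z = rng.standard_normal(params.shape)           # random perturbation direction
    obj_plus = objective(params + eps * z)          # forward evaluation 1
    obj_minus = objective(params - eps * z)         # forward evaluation 2
    grad_est = (obj_plus - obj_minus) / (2 * eps)   # projected "gradient" along z
    return params + lr * grad_est * z               # ascend the estimate
```

On a toy 1-D objective this converges in a few hundred steps; with a discrete metric like accuracy the two evaluations often tie, which is one reason large batch sizes and careful learning rates matter in practice.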

gaotianyu1350 commented 4 weeks ago

Hi,

You can run the non-differentiable example (large models, SQuAD; also mentioned in the README) with:

MODEL=facebook/opt-13b TASK=SQuAD MODE=prefix LR=1e-2 EPS=1e-1 bash mezo.sh --non_diff --evaluation_strategy no --save_strategy no --save_model

The implementation is here: https://github.com/princeton-nlp/MeZO/blob/552cb1b710767f9a6e1dc8f9645d7640376f9941/large_models/trainer.py#L734
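(For context, the key memory trick in MeZO is that the perturbation z is never stored: it is regenerated from a saved random seed for each of the perturb/restore/update passes, so memory stays at inference level. A minimal NumPy sketch of that seed-replay pattern; `mezo_step` and all hyperparameters here are illustrative, not the repo's actual code:)

```python
import numpy as np

def mezo_step(params, loss_fn, lr=1e-3, eps=1e-3, seed=0):
    """One MeZO step (minimizing `loss_fn`) over a list of parameter arrays.

    The perturbation z is regenerated from `seed` each time instead of
    being stored, so the only extra memory cost is the RNG state.
    """
    def perturb(scale):
        rng = np.random.default_rng(seed)        # same seed -> same z every call
        for p in params:
            p += scale * rng.standard_normal(p.shape)

    perturb(+eps)                                # theta + eps * z
    loss_plus = loss_fn(params)
    perturb(-2 * eps)                            # theta - eps * z
    loss_minus = loss_fn(params)
    perturb(+eps)                                # restore theta
    grad_proj = (loss_plus - loss_minus) / (2 * eps)
    rng = np.random.default_rng(seed)            # replay z once more for the update
    for p in params:
        p -= lr * grad_proj * rng.standard_normal(p.shape)
```

Everything happens in place on the parameter arrays; only two scalars (the two losses) and the seed need to be kept between passes.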

Ziiiirem commented 4 weeks ago

Thanks for your reply! Sure, I have already tried fine-tuning OPT-13B on the SQuAD dataset with MeZO, and the result is quite good. I want to try more non-differentiable examples, such as classification tasks (accuracy metric). Could you share this part of your code implementation? I really appreciate your help.

gaotianyu1350 commented 2 weeks ago

Hi Ziming,

I realized the feature is actually already provided: it is implemented under the `--optimize_acc` flag in the medium-sized model folder.

Ziiiirem commented 2 weeks ago

Yes, I have resolved my issue, and I am very grateful for your enthusiastic assistance!!!