princeton-nlp / SimPO

[NeurIPS 2024] SimPO: Simple Preference Optimization with a Reference-Free Reward

can't reproduce the results of Mistral-7B-Instruct-DPO #38

Closed RikkiXu closed 3 months ago

RikkiXu commented 4 months ago

Hi, can I ask about the beta and learning rate used for Mistral-7B-Instruct-DPO? I can't reproduce the results reported in the paper.

xiamengzhou commented 4 months ago

@RikkiXu Please check out the hyperparameters here: https://github.com/princeton-nlp/SimPO/blob/main/README.md#hyperparameter-tuning

RikkiXu commented 4 months ago

Hi @xiamengzhou, thanks for the pointer, but I still can't reproduce the results. I used the alignment-handbook with the config below, and my AlpacaEval 2 score is only about 18. Is there anything wrong with my setup? I am running on 8 A100s.

model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
torch_dtype: null
dataset_mixer:
  princeton-nlp/mistral-instruct-ultrafeedback: 1.0
dataset_splits:

xiamengzhou commented 4 months ago

@RikkiXu It seems that you are using max_length: 1024 and max_prompt_length: 512 in your script. We use 2048 and 1800 respectively. Hope this helps address the issue!
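
For reference, a minimal sketch of how these two length fields would look in an alignment-handbook-style recipe (only the length-related keys are shown; every other key stays as in the recipe you are already using):

# length settings per the comment above; all other keys unchanged
max_length: 2048          # maximum total length of prompt + response tokens
max_prompt_length: 1800   # maximum prompt length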

RikkiXu commented 4 months ago

Hi @xiamengzhou, thanks for your reply. I reproduced Mistral-7B-Instruct-SimPO and Mistral-7B-Instruct-DPO strictly following the latest repo. Unfortunately, neither matches the results reported in the paper.

For Mistral-7B-Instruct-SimPO, the only change I made was to go from 4 GPUs to 8 GPUs and set gradient_accumulation_steps to 8, so that the effective batch size stays at 128 (see the sketch after the table below). The final AlpacaEval 2 result for Mistral-7B-Instruct-SimPO is as follows:
 
| Model | length_controlled_winrate | win_rate | standard_error | n_total | avg_length |
|---|---|---|---|---|---|
| Mistral-7B-Instruct-simpo | 27.55 | 29.28 | 1.32 | 805 | 2352 |
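
For reference, the batch-size arithmetic behind that change, written as a hypothetical recipe fragment; the per_device_train_batch_size of 2 is an assumption based on the repo's default setup (4 GPUs x 2 per device x 16 accumulation steps = 128), not something stated in this thread:

# effective batch size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
#                      = 8 * 2 * 8 = 128
per_device_train_batch_size: 2   # assumed default per-device batch size
gradient_accumulation_steps: 8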

For Mistral-7B-Instruct-DPO, the only change I made was to replace lines 251-258 of run_simpo.py with:

from trl import DPOTrainer

# use the policy model itself as the reference model instead of loading a separate copy
ref_model = model

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)
The result of Mistral-7B-Instruct-DPO is as follows:
 
| Model | length_controlled_winrate | win_rate | standard_error | n_total | avg_length |
|---|---|---|---|---|---|
| Mistral-7B-Instruct-dpo | 21.51 | 17.70 | 1.15 | 805 | 1536 |

When I directly evaluate the checkpoints you published, I get AlpacaEval 2 scores similar to the paper, about 32 and 26 respectively. But when I train the models myself, both Mistral-7B-Instruct-DPO and Mistral-7B-Instruct-SimPO drop about 5 points on AlpacaEval 2. This only happens when reproducing the Mistral-7B-Instruct setting; I was able to reproduce Mistral-7B-Base successfully. Could you please look into it? Thank you again for your reply.

AGTSAAA commented 4 months ago

Hi, @RikkiXu,

Could you please share the versions of your packages (pip list)? I cannot even reproduce Mistral-7B-Base successfully.

Thank you very much.

RikkiXu commented 4 months ago


@AGTSAAA Sure, here is my environment (output of pip list):


absl-py 2.1.0 accelerate 0.26.1 aiofiles 22.1.0 aiohttp 3.9.5 aiosignal 1.3.1 aiosqlite 0.20.0 alabaster 0.7.16 alignment-handbook 0.4.0.dev0 alpaca_eval 0.6.3 annotated-types 0.7.0 anyio 4.4.0 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 async-timeout 4.0.3 attrs 23.1.0 Babel 2.15.0 beautifulsoup4 4.12.3 bidict 0.23.1 bitsandbytes 0.42.0 bleach 6.1.0 byted_remote_ikernel 0.4.8 byted-torch 2.1.0.post2 byted_torch_monitor 0.0.1 byted-wandb 0.13.72 bytedance-context 0.7.1 bytedance-metrics 0.5.1 bytedance.modelhub 0.0.70 bytedance.servicediscovery 0.1.2 bytedbackgrounds 0.0.6 byteddatabus 1.0.6 byteddps 0.1.2 bytedenv 0.6.2 bytedlogger 0.15.1 bytedmemfd 0.2 bytedmetrics 0.10.2 bytedpymongo 2.0.5 bytedrh2 1.18.9a11 bytedservicediscovery 0.17.4 bytedtcc 1.4.2 bytedtos 1.1.16 bytedtrace 0.3.0 bytedztijwthelper 0.0.23 bytedztispiffe 0.0.14 certifi 2024.6.2 cffi 1.16.0 chardet 5.2.0 charset-normalizer 3.3.2 click 8.1.7 cloudpickle 3.0.0 cmake 3.30.0 comm 0.2.2 crcmod 1.7 cryptography 42.0.8 Cython 3.0.10 datasets 2.16.1 dbus-python 1.2.16 debugpy 1.8.2 decorator 5.1.1 deepspeed 0.12.2 defusedxml 0.7.1 Deprecated 1.2.14 dill 0.3.7 diskcache 5.6.3 distlib 0.3.8 distro 1.9.0 distro-info 1.0+deb11u1 dnspython 2.6.1 docker-pycreds 0.4.0 docstring_parser 0.16 docutils 0.19 einops 0.8.0 email_validator 2.2.0 entrypoints 0.4 eval_type_backport 0.2.0 evaluate 0.4.1 exceptiongroup 1.2.1 executing 2.0.1 fastapi 0.111.0 fastapi-cli 0.0.4 fastjsonschema 2.20.0 filelock 3.15.1 fire 0.6.0 flash-attn 2.5.6 fqdn 1.5.1 frozenlist 1.4.1 fsspec 2023.10.0 gitdb 4.0.11 GitPython 3.1.43 grpcio 1.64.1 h11 0.14.0 hf_transfer 0.1.6 hjson 3.1.0 httpcore 1.0.5 httptools 0.6.1 httpx 0.27.0 huggingface-hub 0.23.4 idna 3.7 imagesize 1.4.1 importlib_metadata 7.1.0 iniconfig 2.0.0 interegular 0.3.3 iotop 0.6 ipaddress 1.0.23 ipykernel 6.29.5 ipython 8.18.1 ipython-genutils 0.2.0 ipywidgets 8.1.3 isoduration 20.11.0 jedi 0.19.1 Jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema-specifications 2023.12.1 jupyter 1.0.0 jupyter-client 7.0.0 jupyter-console 6.6.3 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter_server 2.14.1 jupyter_server_fileid 0.9.2 jupyter_server_terminals 0.5.3 jupyter_server_ydoc 0.8.0 jupyter-ydoc 0.2.5 jupyterlab 3.6.4 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.2 jupyterlab_widgets 3.0.11 lark 1.1.9 llvmlite 0.43.0 Markdown 3.6 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib-inline 0.1.7 mdurl 0.1.2 mistune 3.0.2 mpmath 1.3.0 msgpack 1.0.8 multidict 6.0.5 multiprocess 0.70.15 nbclassic 1.1.0 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 networkx 3.2.1 ninja 1.11.1.1 none 0.1.1 notebook 6.5.7 notebook_shim 0.2.4 numba 0.60.0 numpy 1.26.4 openai 1.35.13 orjson 3.10.6 outlines 0.0.34 overrides 7.7.0 packaging 24.1 pandas 2.2.2 pandocfilters 1.5.1 parso 0.8.4 pathlib2 2.3.7.post1 pathtools 0.1.2 patsy 0.5.6 peft 0.8.2 pexpect 4.8.0 pillow 10.2.0 pip 24.1.2 platformdirs 4.2.2 pluggy 1.5.0 ply 3.11 prometheus_client 0.20.0 promise 2.3 prompt_toolkit 3.0.47 protobuf 3.20.2 psutil 6.0.0 ptyprocess 0.7.0 pure-eval 0.2.2 py 1.11.0 py-cpuinfo 9.0.0 py-spy 0.3.14 pyarrow 16.1.0 pyarrow-hotfix 0.6 pycparser 2.22 pycryptodomex 3.20.0 pycurl 7.43.0.6 pydantic 2.8.2 pydantic_core 2.20.1 Pygments 2.18.0 PyGObject 3.38.0 PyJWT 2.8.0 pynvml 11.5.0 pyOpenSSL 24.1.0 pytest 6.2.5 python-apt 2.2.1 python-consul 1.1.0 python-dateutil 2.9.0.post0 python-dotenv 1.0.1 python-engineio 4.9.1 python-etcd 0.4.5 python-json-logger 2.0.7 
python-multipart 0.0.9 python-socketio 5.11.3 pytz 2024.1 PyYAML 6.0.1 pyzmq 26.0.3 qtconsole 5.5.2 QtPy 2.4.1 ray 2.32.0 referencing 0.35.1 regex 2024.5.15 requests 2.32.3 responses 0.18.0 rfc3339-validator 0.1.4 rfc3986 2.0.0 rfc3986-validator 0.1.1 rich 13.7.1 rpds-py 0.19.0 safetensors 0.4.3 schedule 1.2.2 scikit-learn 1.5.1 scipy 1.13.1 Send2Trash 1.8.3 sentencepiece 0.2.0 sentry-sdk 2.9.0 setproctitle 1.3.3 setuptools 69.5.1 shellingham 1.5.4 shortuuid 1.0.13 shtab 1.7.1 simple-websocket 1.0.0 six 1.16.0 smmap 5.0.1 sniffio 1.3.1 snowballstemmer 2.2.0 soupsieve 2.5 Sphinx 5.3.0 sphinxcontrib-applehelp 1.0.8 sphinxcontrib-devhelp 1.0.6 sphinxcontrib-htmlhelp 2.0.5 sphinxcontrib-jsmath 1.0.1 sphinxcontrib-qthelp 1.0.7 sphinxcontrib-serializinghtml 1.1.10 sphinxcontrib-websupport 1.2.7 stack-data 0.6.3 starlette 0.37.2 sympy 1.12.1 tensorboard 2.17.0 tensorboard-data-server 0.7.2 termcolor 2.4.0 terminado 0.18.1 threadpoolctl 3.5.0 thriftpy2 0.5.2 tiktoken 0.6.0 tinycss2 1.3.0 tokenizers 0.19.1 toml 0.10.2 tomli 2.0.1 torch 2.1.2+cu118 torchaudio 2.1.2+cu118 torchvision 0.16.2+cu118 tornado 6.4.1 tox 3.28.0 tqdm 4.66.4 traitlets 5.14.3 transformers 4.41.1 triton 2.1.0 trl 0.8.6 typer 0.12.3 types-python-dateutil 2.9.0.20240316 typing_extensions 4.12.2 tyro 0.8.5 tzdata 2024.1 ujson 5.10.0 unattended-upgrades 0.1 uri-template 1.3.0 urllib3 1.26.19 uvicorn 0.30.1 uvloop 0.19.0 virtualenv 20.26.3 vllm 0.4.0+cu118 watchdog 4.0.1 watchfiles 0.22.0 wcwidth 0.2.13 webcolors 24.6.0 webencodings 0.5.1 websocket-client 1.8.0 websockets 12.0 Werkzeug 3.0.3 wheel 0.43.0 widgetsnbextension 4.0.11 wrapt 1.16.0 wsproto 1.2.0 xformers 0.0.23.post1+cu118 xxhash 3.4.1 y-py 0.6.2 yarl 1.9.4 ypy-websocket 0.8.4 zipp 3.19.2

xiamengzhou commented 4 months ago

@RikkiXu @AGTSAAA Another important note is to use this version of AlpacaEval 2 for evaluation. Recent updates to AlpacaEval might affect the results, as detailed in this issue.

I am looking into the Mistral-Instruct reproducibility issue, and will get back to you soon.

xiamengzhou commented 4 months ago

@RikkiXu

Hi! I ran experiments with the most recent version of the repo using several different random seeds and got the following results.

| Model | Version | LC Win Rate | Win Rate | Std | Length | N |
|---|---|---|---|---|---|---|
| Reproduction: mistral-7b-instruct-simpo-seed0 | mistral instruct | 30.4 | 34.2 | 1.4 | 2342 | 805 |
| Reproduction: mistral-7b-instruct-simpo-seed1 | mistral instruct | 30.7 | 34.7 | 1.4 | 2443 | 805 |
| Reproduction: mistral-7b-instruct-simpo-seed2 | mistral instruct | 30.0 | 34.3 | 1.4 | 2420 | 805 |
| Reported in SimPO paper | mistral instruct | 32.1 | 34.8 | 1.4 | 2193 | 805 |
| Your run | mistral instruct | 27.6 | 29.3 | 1.3 | 2352 | 805 |

My reproduction differs slightly from our reported results and produces somewhat longer outputs; we are not sure why yet and are still looking into it. Overall, however, these runs are better than yours. Additionally, it seems that you are using alpaca_eval==0.6.3. We noticed a discrepancy between versions 0.6.2 and 0.6.3 in the vLLM decoding path that causes performance degradation. Could you please use alpaca_eval==0.6.2 for decoding?

Best, Mengzhou

junkangwu commented 3 months ago

This could be related to the tokenizer:

Align tokenizer with mistral-common (#141)

xiamengzhou commented 3 months ago

@junkangwu Thanks for identifying this. We indeed used the older version of the tokenizer in our previous runs!