@RikkiXu Please check out the hyperparameters here: https://github.com/princeton-nlp/SimPO/blob/main/README.md#hyperparameter-tuning
Hi @xiamengzhou, thanks for your update, but I still can't reproduce the results. I used alignment-handbook with the config below, and AlpacaEval is only about 18. Is there anything wrong? I am experimenting on 8 A100s.
    model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
    torch_dtype: null
    dataset_mixer:
      princeton-nlp/mistral-instruct-ultrafeedback: 1.0
    dataset_splits:
@RikkiXu It seems that you are using max_length: 1024, max_prompt_length: 512 in your script. We use 2048 and 1800 respectively in our script. Hope it helps address the issue!
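As a quick way to catch this kind of mismatch, here is a minimal sanity-check sketch (not part of the repo; the config path is a placeholder) that asserts the two length settings discussed above before launching training:

```python
# Sketch: verify the two length settings discussed above before training.
# "recipe.yaml" is a placeholder path, not a file shipped with the SimPO repo.
import yaml

with open("recipe.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg.get("max_length") == 2048, f"max_length={cfg.get('max_length')}, expected 2048"
assert cfg.get("max_prompt_length") == 1800, f"max_prompt_length={cfg.get('max_prompt_length')}, expected 1800"
```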
Hi @xiamengzhou, thanks for your reply. I reproduced Mistral-7B-Instruct-SimPO and Mistral-7B-Instruct-DPO strictly following the latest repo. Unfortunately, neither matched the results reported in the paper.
For Mistral-7B-Instruct-SimPO, the only change I made was to go from 4 GPUs to 8 GPUs and set gradient_accumulation_steps to 8, so that the batch size stays 128 (the arithmetic is sketched after the table below). The final AlpacaEval 2 result for Mistral-7B-Instruct-SimPO is as follows:
| Model | length_controlled_winrate | win_rate | standard_error | n_total | avg_length |
|---|---|---|---|---|---|
| Mistral-7B-Instruct-simpo | 27.55 | 29.28 | 1.32 | 805 | 2352 |
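For reference, the effective batch size arithmetic for this 8-GPU run (a minimal sketch; only the GPU count, gradient_accumulation_steps, and the target of 128 are stated above, which together imply a per-device batch size of 2):

```python
# Effective batch size for the 8-GPU run described above.
num_gpus = 8
gradient_accumulation_steps = 8
per_device_train_batch_size = 2  # implied by 128 / (8 GPUs * 8 accumulation steps)
effective_batch_size = num_gpus * gradient_accumulation_steps * per_device_train_batch_size
print(effective_batch_size)  # 128, matching the original 4-GPU recipe's global batch size
```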
For Mistral-7B-Instruct-DPO, the only change I made was to replace lines 251-258 of run_simpo.py with:
    from trl import DPOTrainer

    ref_model = model
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["test"],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
    )
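One detail worth flagging (it relates to the question about DPO hyperparameters at the end of this thread): the snippet above does not pass `beta`, so in trl 0.8.x the trainer falls back to its default of 0.1. A hedged sketch of a more explicit call, assuming the surrounding variables from run_simpo.py are in scope; the values shown are placeholders, not the paper's settings:

```python
# Sketch, not the SimPO repo's code. Assumes model, training_args, raw_datasets,
# tokenizer, model_args, and get_peft_config from run_simpo.py are in scope.
# In trl 0.8.x, beta and the length limits are DPOTrainer constructor arguments.
from trl import DPOTrainer

trainer = DPOTrainer(
    model=model,
    ref_model=model,                   # as in the snippet above; trl's docs suggest passing a copy or None instead
    args=training_args,
    beta=0.1,                          # placeholder; trl's default when omitted
    max_length=2048,                   # the lengths discussed earlier in this thread
    max_prompt_length=1800,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)
```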
The result of Mistral-7B-Instruct-DPO is as follows:
| Model | length_controlled_winrate | win_rate | standard_error | n_total | avg_length |
|---|---|---|---|---|---|
| Mistral-7B-Instruct-dpo | 21.51 | 17.70 | 1.15 | 805 | 1536 |
When I directly evaluate the checkpoints you published, I get AlpacaEval 2 scores similar to the paper (32 and 26, respectively). But when I train them myself, both Mistral-7B-Instruct-DPO and Mistral-7B-Instruct-SimPO drop about 5 points on AlpacaEval 2. This only happens when reproducing Mistral-7B-Instruct; I was able to reproduce Mistral-7B-Base successfully. Could you please look into it? Thank you very much again for your reply.
Hi, @RikkiXu,
Could you please share your package versions (pip list)? I cannot even reproduce Mistral-7B-Base successfully.
Thank you very much.
Package Version
absl-py 2.1.0 accelerate 0.26.1 aiofiles 22.1.0 aiohttp 3.9.5 aiosignal 1.3.1 aiosqlite 0.20.0 alabaster 0.7.16 alignment-handbook 0.4.0.dev0 alpaca_eval 0.6.3 annotated-types 0.7.0 anyio 4.4.0 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 async-timeout 4.0.3 attrs 23.1.0 Babel 2.15.0 beautifulsoup4 4.12.3 bidict 0.23.1 bitsandbytes 0.42.0 bleach 6.1.0 byted_remote_ikernel 0.4.8 byted-torch 2.1.0.post2 byted_torch_monitor 0.0.1 byted-wandb 0.13.72 bytedance-context 0.7.1 bytedance-metrics 0.5.1 bytedance.modelhub 0.0.70 bytedance.servicediscovery 0.1.2 bytedbackgrounds 0.0.6 byteddatabus 1.0.6 byteddps 0.1.2 bytedenv 0.6.2 bytedlogger 0.15.1 bytedmemfd 0.2 bytedmetrics 0.10.2 bytedpymongo 2.0.5 bytedrh2 1.18.9a11 bytedservicediscovery 0.17.4 bytedtcc 1.4.2 bytedtos 1.1.16 bytedtrace 0.3.0 bytedztijwthelper 0.0.23 bytedztispiffe 0.0.14 certifi 2024.6.2 cffi 1.16.0 chardet 5.2.0 charset-normalizer 3.3.2 click 8.1.7 cloudpickle 3.0.0 cmake 3.30.0 comm 0.2.2 crcmod 1.7 cryptography 42.0.8 Cython 3.0.10 datasets 2.16.1 dbus-python 1.2.16 debugpy 1.8.2 decorator 5.1.1 deepspeed 0.12.2 defusedxml 0.7.1 Deprecated 1.2.14 dill 0.3.7 diskcache 5.6.3 distlib 0.3.8 distro 1.9.0 distro-info 1.0+deb11u1 dnspython 2.6.1 docker-pycreds 0.4.0 docstring_parser 0.16 docutils 0.19 einops 0.8.0 email_validator 2.2.0 entrypoints 0.4 eval_type_backport 0.2.0 evaluate 0.4.1 exceptiongroup 1.2.1 executing 2.0.1 fastapi 0.111.0 fastapi-cli 0.0.4 fastjsonschema 2.20.0 filelock 3.15.1 fire 0.6.0 flash-attn 2.5.6 fqdn 1.5.1 frozenlist 1.4.1 fsspec 2023.10.0 gitdb 4.0.11 GitPython 3.1.43 grpcio 1.64.1 h11 0.14.0 hf_transfer 0.1.6 hjson 3.1.0 httpcore 1.0.5 httptools 0.6.1 httpx 0.27.0 huggingface-hub 0.23.4 idna 3.7 imagesize 1.4.1 importlib_metadata 7.1.0 iniconfig 2.0.0 interegular 0.3.3 iotop 0.6 ipaddress 1.0.23 ipykernel 6.29.5 ipython 8.18.1 ipython-genutils 0.2.0 ipywidgets 8.1.3 isoduration 20.11.0 jedi 0.19.1 Jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema-specifications 2023.12.1 jupyter 1.0.0 jupyter-client 7.0.0 jupyter-console 6.6.3 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter_server 2.14.1 jupyter_server_fileid 0.9.2 jupyter_server_terminals 0.5.3 jupyter_server_ydoc 0.8.0 jupyter-ydoc 0.2.5 jupyterlab 3.6.4 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.2 jupyterlab_widgets 3.0.11 lark 1.1.9 llvmlite 0.43.0 Markdown 3.6 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib-inline 0.1.7 mdurl 0.1.2 mistune 3.0.2 mpmath 1.3.0 msgpack 1.0.8 multidict 6.0.5 multiprocess 0.70.15 nbclassic 1.1.0 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 networkx 3.2.1 ninja 1.11.1.1 none 0.1.1 notebook 6.5.7 notebook_shim 0.2.4 numba 0.60.0 numpy 1.26.4 openai 1.35.13 orjson 3.10.6 outlines 0.0.34 overrides 7.7.0 packaging 24.1 pandas 2.2.2 pandocfilters 1.5.1 parso 0.8.4 pathlib2 2.3.7.post1 pathtools 0.1.2 patsy 0.5.6 peft 0.8.2 pexpect 4.8.0 pillow 10.2.0 pip 24.1.2 platformdirs 4.2.2 pluggy 1.5.0 ply 3.11 prometheus_client 0.20.0 promise 2.3 prompt_toolkit 3.0.47 protobuf 3.20.2 psutil 6.0.0 ptyprocess 0.7.0 pure-eval 0.2.2 py 1.11.0 py-cpuinfo 9.0.0 py-spy 0.3.14 pyarrow 16.1.0 pyarrow-hotfix 0.6 pycparser 2.22 pycryptodomex 3.20.0 pycurl 7.43.0.6 pydantic 2.8.2 pydantic_core 2.20.1 Pygments 2.18.0 PyGObject 3.38.0 PyJWT 2.8.0 pynvml 11.5.0 pyOpenSSL 24.1.0 pytest 6.2.5 python-apt 2.2.1 python-consul 1.1.0 python-dateutil 2.9.0.post0 python-dotenv 1.0.1 python-engineio 4.9.1 python-etcd 0.4.5 python-json-logger 2.0.7 
python-multipart 0.0.9 python-socketio 5.11.3 pytz 2024.1 PyYAML 6.0.1 pyzmq 26.0.3 qtconsole 5.5.2 QtPy 2.4.1 ray 2.32.0 referencing 0.35.1 regex 2024.5.15 requests 2.32.3 responses 0.18.0 rfc3339-validator 0.1.4 rfc3986 2.0.0 rfc3986-validator 0.1.1 rich 13.7.1 rpds-py 0.19.0 safetensors 0.4.3 schedule 1.2.2 scikit-learn 1.5.1 scipy 1.13.1 Send2Trash 1.8.3 sentencepiece 0.2.0 sentry-sdk 2.9.0 setproctitle 1.3.3 setuptools 69.5.1 shellingham 1.5.4 shortuuid 1.0.13 shtab 1.7.1 simple-websocket 1.0.0 six 1.16.0 smmap 5.0.1 sniffio 1.3.1 snowballstemmer 2.2.0 soupsieve 2.5 Sphinx 5.3.0 sphinxcontrib-applehelp 1.0.8 sphinxcontrib-devhelp 1.0.6 sphinxcontrib-htmlhelp 2.0.5 sphinxcontrib-jsmath 1.0.1 sphinxcontrib-qthelp 1.0.7 sphinxcontrib-serializinghtml 1.1.10 sphinxcontrib-websupport 1.2.7 stack-data 0.6.3 starlette 0.37.2 sympy 1.12.1 tensorboard 2.17.0 tensorboard-data-server 0.7.2 termcolor 2.4.0 terminado 0.18.1 threadpoolctl 3.5.0 thriftpy2 0.5.2 tiktoken 0.6.0 tinycss2 1.3.0 tokenizers 0.19.1 toml 0.10.2 tomli 2.0.1 torch 2.1.2+cu118 torchaudio 2.1.2+cu118 torchvision 0.16.2+cu118 tornado 6.4.1 tox 3.28.0 tqdm 4.66.4 traitlets 5.14.3 transformers 4.41.1 triton 2.1.0 trl 0.8.6 typer 0.12.3 types-python-dateutil 2.9.0.20240316 typing_extensions 4.12.2 tyro 0.8.5 tzdata 2024.1 ujson 5.10.0 unattended-upgrades 0.1 uri-template 1.3.0 urllib3 1.26.19 uvicorn 0.30.1 uvloop 0.19.0 virtualenv 20.26.3 vllm 0.4.0+cu118 watchdog 4.0.1 watchfiles 0.22.0 wcwidth 0.2.13 webcolors 24.6.0 webencodings 0.5.1 websocket-client 1.8.0 websockets 12.0 Werkzeug 3.0.3 wheel 0.43.0 widgetsnbextension 4.0.11 wrapt 1.16.0 wsproto 1.2.0 xformers 0.0.23.post1+cu118 xxhash 3.4.1 y-py 0.6.2 yarl 1.9.4 ypy-websocket 0.8.4 zipp 3.19.2
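For quicker environment comparisons, here is a small sketch (not from the repo) that prints only the packages most relevant to this thread instead of the full pip list:

```python
# Print versions of the packages most relevant to this reproduction thread.
import importlib.metadata as metadata

for pkg in ("torch", "transformers", "trl", "accelerate", "deepspeed",
            "alignment-handbook", "alpaca_eval", "vllm"):
    try:
        print(f"{pkg}=={metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```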
@RikkiXu @AGTSAAA Another important note: please use this version of AlpacaEval 2 for evaluation. Recent updates might affect the results, as detailed in this issue.
I am looking into the Mistral-Instruct reproducibility issue, and will get back to you soon.
@RikkiXu Hi! I managed to run experiments with the most up-to-date repo using multiple different seeds, and got the following results:
| Model | Version | LC Win Rate | Win Rate | Std | Length | N |
|---|---|---|---|---|---|---|
| Reproduction mistral-7b-instruct-simpo-seed0 | mistral instruct | 30.4 | 34.2 | 1.4 | 2342 | 805 |
| Reproduction mistral-7b-instruct-simpo-seed1 | mistral instruct | 30.7 | 34.7 | 1.4 | 2443 | 805 |
| Reproduction mistral-7b-instruct-simpo-seed2 | mistral instruct | 30.0 | 34.3 | 1.4 | 2420 | 805 |
| Reported in SimPO Paper | mistral instruct | 32.1 | 34.8 | 1.4 | 2193 | 805 |
| Your run | mistral instruct | 27.6 | 29.3 | 1.3 | 2352 | 805 |
It appears my reproduction differs slightly from our reported results and produces somewhat longer outputs. We are not sure why :( and are still looking into it. Overall, however, these runs are better than yours. Additionally, it seems you are using alpaca_eval==0.6.3. We noticed a discrepancy between versions 0.6.2 and 0.6.3 in vLLM decoding that causes performance degradation. Could you please use alpaca_eval==0.6.2 for decoding?
Best, Mengzhou
This could be related to the tokenizer:
Align tokenizer with mistral-common (#141)
@junkangwu thanks for identifying this -- We indeed used the older version of the tokenizer in our previous version!
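For anyone re-running with the current hub files, one way to rule this out is to pin the tokenizer to a revision predating that commit (a sketch; the revision string is a placeholder, and the actual commit hash would need to be looked up on the model's hub history):

```python
# Sketch: load the tokenizer from a fixed hub revision predating the
# "Align tokenizer with mistral-common" change. The revision value is a
# placeholder, not a real commit hash.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    revision="<commit-before-mistral-common-alignment>",
)
```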
Hi, can I ask about the beta and learning rate used for Mistral-7B-Instruct-DPO? I can't reproduce the results in the paper.