HideLord opened this issue 4 months ago
I just tested Phi-3-medium, and it works as expected.
The difference is that I ran Phi-3-medium with ds_z3_config.json (ZeRO-3) instead of ds_z2_config.
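Because the working Phi-3-medium run and the failing Phi-3-small run used different DeepSpeed files, it may help to see what actually differs between a ZeRO-2 and a ZeRO-3 config. The sketch below builds both in Python; the key set is a trimmed-down assumption, not a copy of the ds_z2_config.json / ds_z3_config.json files used in this thread, and `make_zero_config` is a hypothetical helper.

```python
import json


def make_zero_config(stage: int) -> dict:
    """Build a minimal DeepSpeed config for ZeRO stage 2 or 3.

    Assumption: a trimmed-down key set; "auto" values are filled in by the
    Hugging Face Trainer's DeepSpeed integration at launch time.
    """
    if stage not in (2, 3):
        raise ValueError("only ZeRO stages 2 and 3 are sketched here")
    config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "bf16": {"enabled": "auto"},
        "zero_optimization": {
            # Stage 2 shards optimizer states and gradients across GPUs;
            # stage 3 additionally shards the model parameters.
            "stage": stage,
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
    }
    if stage == 3:
        # Gather the sharded 16-bit weights into one checkpoint on save.
        config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
    return config


if __name__ == "__main__":
    for stage in (2, 3):
        print(f"ds_z{stage}_config.json:")
        print(json.dumps(make_zero_config(stage), indent=2))
```

Changing only the `stage` value, with the model held fixed, would isolate whether the ZeRO-2 path is what breaks Phi-3-small; the Phi-3-medium comparison above also changes the model, so it does not settle that on its own.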
Observed the same issue.
Hi, I have observed the same issue. Has it been resolved?
When I chat with Phi-3-small, the model often fails to predict the stop token. Perhaps the chat template for Phi-3-small is wrong? A similar issue is reported in #4712.
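If the template is the culprit, one thing worth checking is whether every assistant turn in the rendered training text is closed by the end-of-turn marker; a model that never sees the marker after its answers has nothing to learn to stop on. This is a sketch assuming Phi-3 style markers (`<|user|>`, `<|assistant|>`, `<|end|>`); `format_phi3` and `assistant_turns_terminated` are hypothetical helpers, not LLaMA-Factory or transformers functions, and Phi-3-small's real template may differ (for example in BOS handling).

```python
import re

# Assumption: Phi-3 style turn markers. Whether Phi-3-small prepends a BOS
# token or uses extra special tokens is not checked here.
END = "<|end|>"
ROLE_TAG = re.compile(r"<\|(system|user|assistant)\|>\n")


def format_phi3(messages: list[dict[str, str]]) -> str:
    """Hypothetical renderer: '<|role|>' + newline + content + '<|end|>' + newline per turn."""
    parts = []
    for msg in messages:
        if msg["role"] not in ("system", "user", "assistant"):
            raise ValueError(f"unknown role: {msg['role']}")
        parts.append(f"<|{msg['role']}|>\n{msg['content']}{END}\n")
    return "".join(parts)


def assistant_turns_terminated(rendered: str) -> bool:
    """True if every assistant turn in `rendered` ends with the END marker."""
    # re.split with a capturing group yields [prefix, role, body, role, body, ...]
    pieces = ROLE_TAG.split(rendered)
    roles, bodies = pieces[1::2], pieces[2::2]
    return all(
        body.rstrip("\n").endswith(END)
        for role, body in zip(roles, bodies)
        if role == "assistant"
    )


if __name__ == "__main__":
    convo = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    print(assistant_turns_terminated(format_phi3(convo)))
    print(assistant_turns_terminated("<|user|>\nHi<|end|>\n<|assistant|>\nHello!"))
```

Run as a script, this prints `True` for the well-formed conversation and `False` for the one whose assistant turn is missing `<|end|>`.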
Reminder
Reproduction
To trigger the issue, I trained Phi-3-small with LoRA on 4 GPUs using DeepSpeed with ds_z2_config. Full YAML config:
To trigger it:
The gradient explodes at some point:
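To pin down the step where the explosion starts, the Trainer's `log_history` (also saved as `log_history` in `trainer_state.json` under the output directory) can be scanned for spikes. This sketch assumes the installed transformers version logs a `grad_norm` value per logging step, as recent versions do; the threshold heuristic, `find_grad_norm_spikes`, and the script name `find_spikes.py` are hypothetical choices, not part of LLaMA-Factory.

```python
import math
import statistics


def find_grad_norm_spikes(log_history, factor=10.0, window=20):
    """Return (step, grad_norm) pairs that look like gradient explosions.

    A logged grad_norm is flagged if it is NaN/inf, or if it exceeds `factor`
    times the median of the previous `window` finite values.
    """
    seen = []
    spikes = []
    for entry in log_history:
        norm = entry.get("grad_norm")
        if norm is None:  # e.g. eval or end-of-training entries
            continue
        if not math.isfinite(norm):
            spikes.append((entry.get("step"), norm))
            continue
        if seen and norm > factor * statistics.median(seen[-window:]):
            spikes.append((entry.get("step"), norm))
        seen.append(norm)
    return spikes


if __name__ == "__main__":
    import json
    import sys

    # Usage: python find_spikes.py <output_dir>/trainer_state.json
    if len(sys.argv) > 1:
        with open(sys.argv[1]) as f:
            state = json.load(f)
        for step, norm in find_grad_norm_spikes(state["log_history"]):
            print(f"step {step}: grad_norm={norm}")
```

Using a rolling median rather than the mean keeps a single earlier spike from masking later ones.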
The training also crashes with an out-of-memory (OOM) error:
For comparison, I ran the same configuration with Mistral-7B:
Maybe it's related to the fact that Phi-3-small uses a different architecture from the rest of the family (Phi3SmallForCausalLM vs. Phi3ForCausalLM)?
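One way to confirm which class each checkpoint actually loads is to read the `architectures` field of its `config.json`. As far as I can tell, transformers 4.41.1 ships `Phi3ForCausalLM` natively, while `Phi3SmallForCausalLM` comes from modeling code bundled with the checkpoint, which would mean the two models go through different training code paths. The sketch assumes local copies of the checkpoints; `describe_checkpoint` and the script name `describe_checkpoint.py` are hypothetical.

```python
import json
import sys
from pathlib import Path


def describe_checkpoint(model_dir):
    """Summarize which model class a local checkpoint asks transformers to load."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    return {
        "architectures": config.get("architectures", []),
        "model_type": config.get("model_type"),
        # An "auto_map" entry means the modeling code ships with the checkpoint
        # and is loaded via trust_remote_code=True, not from transformers itself.
        "remote_code": "auto_map" in config,
    }


if __name__ == "__main__":
    # Usage: python describe_checkpoint.py /path/to/local/checkpoint [...]
    for model_dir in sys.argv[1:]:
        print(model_dir, describe_checkpoint(model_dir))
```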
Expected behavior
The gradient should stay stable throughout training instead of exploding.
System Info
accelerate==0.30.1 addict==2.4.0 aiofiles==23.2.1 aiohttp==3.9.5 aiosignal==1.3.1 aliyun-python-sdk-core==2.15.1 aliyun-python-sdk-kms==2.16.3 altair==5.3.0 annotated-types==0.6.0 anyio==4.3.0 aqlm==1.1.5 async-timeout==4.0.3 attrs==23.2.0 auto_gptq==0.7.1 autoawq==0.2.5 autoawq_kernels==0.0.6 bitsandbytes==0.43.1 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 cloudpickle==3.0.0 cmake==3.29.3 coloredlogs==15.0.1 contourpy==1.2.1 crcmod==1.7 cryptography==42.0.7 cycler==0.12.1 datasets==2.18.0 deepspeed==0.14.0 dill==0.3.8 diskcache==5.6.3 distro==1.9.0 dnspython==2.6.1 docker-pycreds==0.4.0 docstring_parser==0.16 einops==0.8.0 email_validator==2.1.1 exceptiongroup==1.2.1 fastapi==0.111.0 fastapi-cli==0.0.3 ffmpy==0.3.2 filelock==3.14.0 fire==0.6.0 flash-attn==2.5.8 fonttools==4.51.0 frozenlist==1.4.1 fsspec==2024.2.0 gast==0.5.4 gekko==1.1.1 gitdb==4.0.11 GitPython==3.1.43 gradio==4.31.4 gradio_client==0.16.4 h11==0.14.0 hjson==3.1.0 httpcore==1.0.5 httptools==0.6.1 httpx==0.27.0 huggingface-hub==0.23.0 humanfriendly==10.0 idna==3.7 importlib_metadata==7.1.0 importlib_resources==6.4.0 iniconfig==2.0.0 interegular==0.3.3 jieba==0.42.1 Jinja2==3.1.4 jmespath==0.10.0 joblib==1.4.2 jsonschema==4.22.0 jsonschema-specifications==2023.12.1 kiwisolver==1.4.5 lark==1.1.9 -e git+https://github.com/hiyouga/LLaMA-Factory.git@419d47c101eab27dbffc0e3bd646c7e43d036fb3#egg=llamafactory llmtuner==0.6.3.dev0 llvmlite==0.42.0 lm-format-enforcer==0.9.8 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 mdurl==0.1.2 modelscope==1.14.0 mpmath==1.3.0 msgpack==1.0.8 multidict==6.0.5 multiprocess==0.70.16 nest-asyncio==1.6.0 networkx==3.3 ninja==1.11.1.1 nltk==3.8.1 numba==0.59.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 
nvidia-cusparse-cu12==12.1.0.106 nvidia-ml-py==12.550.52 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 openai==1.30.1 optimum==1.19.2 orjson==3.10.3 oss2==2.18.5 outlines==0.0.34 packaging==24.0 pandas==2.2.2 peft==0.11.1 pillow==10.3.0 platformdirs==4.2.2 pluggy==1.5.0 prometheus-fastapi-instrumentator==7.0.0 prometheus_client==0.20.0 protobuf==4.25.3 psutil==5.9.8 py-cpuinfo==9.0.0 pyarrow==16.1.0 pyarrow-hotfix==0.6 pycparser==2.22 pycryptodome==3.20.0 pydantic==2.7.1 pydantic_core==2.18.2 pydub==0.25.1 Pygments==2.18.0 pynvml==11.5.0 pyparsing==3.1.2 pytest==8.2.1 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 python-multipart==0.0.9 pytz==2024.1 PyYAML==6.0.1 ray==2.22.0 referencing==0.35.1 regex==2024.5.15 requests==2.31.0 rich==13.7.1 rouge==1.0.1 rouge-chinese==1.0.3 rpds-py==0.18.1 ruff==0.4.4 safetensors==0.4.3 scipy==1.13.0 semantic-version==2.10.0 sentencepiece==0.2.0 sentry-sdk==2.2.0 setproctitle==1.3.3 shellingham==1.5.4 shtab==1.7.1 simplejson==3.19.2 six==1.16.0 smmap==5.0.1 sniffio==1.3.1 sortedcontainers==2.4.0 sse-starlette==2.1.0 starlette==0.37.2 sympy==1.12 termcolor==2.4.0 tiktoken==0.6.0 tokenizers==0.19.1 tomli==2.0.1 tomlkit==0.12.0 toolz==0.12.1 torch==2.3.0 tqdm==4.66.4 transformers==4.41.1 transformers-stream-generator==0.0.5 triton==2.3.0 trl==0.8.6 typer==0.12.3 typing_extensions==4.11.0 tyro==0.8.4 tzdata==2024.1 ujson==5.10.0 urllib3==2.2.1 uvicorn==0.29.0 uvloop==0.19.0 vllm==0.4.2 vllm-nccl-cu12==2.18.1.0.4.0 wandb==0.17.0 watchfiles==0.21.0 websockets==11.0.3 xformers==0.0.26.post1 xxhash==3.4.1 yapf==0.40.2 yarl==1.9.4 zipp==3.18.2 zstandard==0.22.0
Others
No response