647sherry opened this issue 4 months ago
@647sherry Our training was conducted on 8 A100-80G GPUs, which is twice as large as your setting. For larger models, you could try reducing per_device_train_batch_size as needed and increasing gradient_accumulation_steps, such that per_device_train_batch_size * gradient_accumulation_steps = 8.
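For reference, here is a minimal sketch of that suggestion on the command line, assuming run_spin.py accepts TrainingArguments-style overrides after the config path (as the launch command later in this thread does for num_train_epochs and output_dir); the two flags shown are illustrative, not confirmed repo options:

```sh
# Keep the effective per-GPU batch constant:
#   per_device_train_batch_size * gradient_accumulation_steps = 8   (e.g. 1*8, 2*4, 4*2)
accelerate launch --config_file configs/deepspeed_zero3.yaml --multi_gpu \
    spin/run_spin.py configs/config.yaml \
    --per_device_train_batch_size=1 \
    --gradient_accumulation_steps=8
```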
Thanks for your reply. I set per_device_train_batch_size = 1 and gradient_accumulation_steps = 8, with ZeRO stage 2 and stage 3, but I still get OOM.
I ran into a similar issue with a deepseek-math-7b model on an A100 80G GPU. I cannot get things to work even with per_device_train_batch_size = 1 and gradient_accumulation_steps = 1, so I suspect there is some model-related bug here.
@jojo23333 Is that an OOM error or something else? Can you check GPU utilization while this training is not running, to see whether the memory has been taken by other processes? From what I know, OOM should not be possible under your setting.
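A quick way to do that check, using standard nvidia-smi queries (nothing repo-specific assumed):

```sh
# Run while the training job is NOT active; any process still holding GPU memory
# will be listed here with its PID and memory usage.
nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv
# Live view of overall utilization while launching the run:
watch -n 1 nvidia-smi
```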
Hi, thanks so much for the reply! I'm pretty sure all the memory was eaten up by this single process. I did a fairly detailed check: model loading plus the forward pass takes up to roughly 30 GB of the 80 GB, and the backward pass immediately leads to OOM.
That being said, my setup is somewhat different.
I'm not sure whether this has something to do with the response/context length of a single sample; I roughly estimate that my responses are at most 2x longer than those in the superchat dataset (a quick length-cap check is sketched after this comment). Still, not being able to get batch_size = 1 working is somewhat weird.
Actually, I tried zephyr-7b-sft-full with the original setting on my data, and I was able to get training going with a per-device batch size of 8, but somehow not with deepseek.
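If sequence length is the suspect, one cheap test is to see what length cap the training config actually applies. This sketch assumes the SPIN config exposes DPO-style max_length / max_prompt_length fields; the exact field names in configs/config.yaml may differ:

```sh
# Show whatever length caps the config sets; a larger cap means a larger
# per-sample activation footprint during the backward pass.
grep -inE "max_length|max_prompt_length" configs/config.yaml
# Lowering the cap (e.g. max_length: 1024) bounds per-sample memory and is a
# quick way to test the response-length hypothesis at batch_size = 1.
```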
Hi @jojo23333, have you solved the OOM? This might be a related issue: https://github.com/huggingface/transformers/issues/29484
Thanks for the pointer, I'll take a look.
Hi, I got an OOM error while fine-tuning with qwen-14b-chat and the default model, using:
```sh
accelerate launch --config_file configs/deepspeed_zero3.yaml --multi_gpu \
    --num_processes=8 --main_process_port 29501 \
    spin/run_spin.py configs/config.yaml \
    --num_train_epochs=3 --output_dir="xxx/spin_outputs/iter0-ckpt"
```
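For a 14B model on ZeRO-3, it is also worth confirming what the DeepSpeed config actually enables, since whether optimizer/parameter offload is on makes a large memory difference. The key names below follow accelerate's DeepSpeed plugin YAML; the values are only examples, not the repo's defaults:

```sh
cat configs/deepspeed_zero3.yaml
# Relevant keys to look for (accelerate DeepSpeed plugin format):
#   deepspeed_config:
#     zero_stage: 3
#     offload_optimizer_device: cpu   # 'none' keeps optimizer states on GPU
#     offload_param_device: cpu       # 'none' keeps parameters on GPU
```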
System info:

```text
absl-py 2.1.0
accelerate 0.23.0
aiohttp 3.9.5
aioprometheus 23.12.0
aiosignal 1.3.1
annotated-types 0.7.0
anyio 4.4.0
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.41.2.post2
certifi 2024.6.2
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
cmake 3.29.6
contourpy 1.2.1
cycler 0.12.1
datasets 2.14.6
deepspeed 0.12.2
dill 0.3.7
diskcache 5.6.3
dnspython 2.6.1
docstring_parser 0.16
einops 0.8.0
email_validator 2.1.1
evaluate 0.4.0
exceptiongroup 1.2.1
fastapi 0.111.0
fastapi-cli 0.0.4
filelock 3.15.1
flash_attn 2.5.9.post1
fonttools 4.53.0
frozenlist 1.4.1
fsspec 2023.10.0
grpcio 1.64.1
h11 0.14.0
hjson 3.1.0
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.3
idna 3.7
interegular 0.3.3
Jinja2 3.1.4
joblib 1.4.2
jsonlines 4.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lark 1.1.9
llvmlite 0.43.0
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
opencv-python 4.10.0.84
orjson 3.10.5
outlines 0.0.34
packaging 24.1
pandas 2.2.2
peft 0.6.1
pillow 10.4.0
pip 24.0
prometheus_client 0.20.0
protobuf 3.20.2
psutil 5.9.8
py-cpuinfo 9.0.0
py4j 0.10.9.7
pyarrow 16.1.0
pydantic 2.7.4
pydantic_core 2.18.4
Pygments 2.18.0
pynvml 11.5.0
pyparsing 3.1.2
pyspark 3.5.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.9
pytz 2024.1
PyYAML 6.0.1
quantile-python 1.1
ray 2.24.0
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
responses 0.18.0
rich 13.7.1
rpds-py 0.18.1
safetensors 0.4.3
scipy 1.13.1
seaborn 0.13.2
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
shtab 1.7.1
six 1.16.0
sniffio 1.3.1
spin 0.1.0.dev0
starlette 0.37.2
sympy 1.12.1
tensorboard 2.17.0
tensorboard-data-server 0.7.2
tiktoken 0.6.0
tokenizers 0.15.2
torch 2.1.0
torchvision 0.18.1
tqdm 4.66.4
transformers 4.36.2
transformers-stream-generator 0.0.5
triton 2.1.0
trl 0.7.4
typer 0.12.3
typing_extensions 4.12.2
tyro 0.8.4
tzdata 2024.1
ujson 5.10.0
ultralytics-thop 2.0.0
urllib3 2.2.1
uvicorn 0.30.1
uvloop 0.19.0
vllm 0.3.0
watchfiles 0.22.0
websockets 12.0
Werkzeug 3.0.3
wheel 0.43.0
xformers 0.0.23.post1
xxhash 3.4.1
yarl 1.9.4
```
Thanks in advance for your help!