ContextualAI / HALOs

A library with extensible implementations of DPO, KTO, PPO, ORPO, and other human-aware loss functions (HALOs).
https://arxiv.org/abs/2402.01306
Apache License 2.0

ERROR: None of the inputs have requires_grad=True. Gradients will be None #14

Closed: Pattaro closed this issue 5 months ago

Pattaro commented 6 months ago

```
Computing eval metrics:   0%|          | 0/86 [00:00<?, ?it/s]
/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
```

Pattaro commented 6 months ago

```
-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/checkpoint/binary/train_package/train.py", line 84, in worker_main
    trainer.train()
  File "/checkpoint/binary/train_package/trainers.py", line 408, in train
    (loss / self.config.model.gradient_accumulation_steps).backward()
  File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
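For context, the warning and the RuntimeError are two symptoms of the same failure: when gradient checkpointing is active and none of the tensor inputs to a checkpointed segment require grad, the segment's output is detached from the autograd graph, so the loss ends up with no grad_fn. A minimal sketch (plain PyTorch, not HALOs code, assuming the reentrant checkpoint that was the default in torch 2.0):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)  # the layer's parameters require grad, as expected
x = torch.randn(2, 4)          # but the *input* tensor does not require grad

# With the reentrant implementation, this emits the exact UserWarning above
# and returns an output that is cut off from the autograd graph.
y = checkpoint(layer, x, use_reentrant=True)
loss = y.sum()
print(loss.requires_grad)  # False: calling loss.backward() would raise the
                           # "element 0 of tensors does not require grad" RuntimeError
```

With transformers models, a common workaround is to call model.enable_input_require_grads() (or otherwise make the embedding output require grad) before enabling gradient checkpointing.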

Pattaro commented 6 months ago

But when I print requires_grad for everything in model.state_dict(), all of the parameters show True.
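Note that parameters having requires_grad=True is necessary but not sufficient: the loss must also be connected to them through the autograd graph. Also, model.state_dict() detaches its tensors by default (keep_vars=False), so flags read from it can be misleading. A minimal sketch of a more direct check (plain PyTorch; the Linear is a stand-in for the policy model):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the policy model

# Inspect the live parameters rather than the (detached) state_dict tensors.
for name, p in model.named_parameters():
    print(name, p.requires_grad)

# Even with every flag True, loss.backward() fails if the forward pass was
# severed from the graph (e.g. by .detach(), torch.no_grad, or a checkpointed
# segment whose inputs do not require grad).
```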

kawine commented 6 months ago

Can you provide more context and include the exact command you ran?

  1. What dataset are you using? Do you get the same error with one of the default datasets?
  2. Did you write a new trainer, dataloader, etc.?
  3. Are you using the same conda env provided in the package?

Pattaro commented 6 months ago

[requirements] The base model I ran is Qwen-14B, and to run it I used the following requirements:

```
hydra-core==1.3.2
transformers==4.38.2
tensor_parallel
datasets==2.16.1
wandb
transformers_stream_generator
einops
accelerate
trl
flash-attn
deepspeed
tiktoken
python==3.10
pytorch==2.0.1
cuda==11.7
```

[datasets] I used a custom dataset and wrote a new dataloader for it (screenshots omitted). However, when running SFT with the Llama-7B model on the HH dataset, it still prints the warning "UserWarning: None of the inputs have requires_grad=True. Gradients will be None", but it does not throw the error "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn".

[trainer] I used the KTOTrainer.

Pattaro commented 6 months ago

Additionally, I noticed that the parameters of the policy model did not change during training.

kawine commented 6 months ago

Can you create a conda env from environment.yaml? Small differences in package versioning can affect whether CUDA works with FSDP / flash-attention.

Secondly, I don't think you have a preference dataset -- you have a binary feedback dataset, so you should be populating the Example.desirable field with [True, True, False, True, ...], not the Example.pairs field (which is for preferences only).
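For illustration, a hedged sketch of what this means (the desirable and pairs field names come from this thread; the import path, constructor, and prompt/generations fields are assumptions about the HALOs Example dataclass, so check dataloader.py for the exact schema):

```python
from dataloader import Example  # assumed import path for the HALOs Example dataclass

ex = Example()
ex.prompt = "How do I sort a list in Python?"
ex.generations = ["Use sorted(xs).", "Just shuffle it until it is sorted."]

# KTO consumes binary feedback: one True/False label per generation.
ex.desirable = [True, False]

# pairs is only for preference losses like DPO, where each entry indexes a
# (chosen, rejected) pair of generations; leave it empty for KTO data.
ex.pairs = []
```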

Pattaro commented 6 months ago

Regarding the second point, I think it's fine: in my implementation, the get_flat_data part of the dataloader ensures that we correctly identify which response is chosen and which is not.

My environment.yaml is as follows:

```
absl-py==2.1.0 accelerate==0.27.2 acm-sdk-python==0.4.11 ai-scheduler @ http://shanzhongpub.oss-cn-hangzhou-zmf.aliyuncs.com/ai-scheduler/ai_scheduler-0.2_detect_pod_hung-py3-none-any.whl#sha256=3c73b9a3bf0a7c0bdc2d92739a700da3eae215264560a8b71bf2844d6b45ce51 aiohttp==3.9.3 aiosignal==1.3.1 aliyun-python-sdk-core==2.14.0 aliyun-python-sdk-kms==2.16.2 antlr4-python3-runtime==4.9.3 apex @ file:///apex-23.08 appdirs==1.4.4 astunparse==1.6.3 async-timeout==4.0.3 attrs==23.2.0 av==11.0.0 bitsandbytes==0.41.0 certifi==2024.2.2 cffi==1.16.0 chardet==3.0.4 charset-normalizer==3.3.2 click==8.1.7 cmake==3.28.3 common-io @ http://nebula-cv.oss-cn-zhangjiakou.aliyuncs.com/docker_image/common_io-0.8.7-cp310-cp310-linux_x86_64.whl#sha256=bdb1419d18d463536007115fbd4439941a235f6e7a823241f6fc0789c71bb450 crcmod==1.7 cryptography==42.0.5 cycler==0.12.1 Cython==3.0.8 datasets==2.16.1 deepspeed==0.8.2 dill==0.3.7 docker-pycreds==0.4.0 docopt==0.6.2 docstring-parser==0.15 easydict==1.12 einops==0.7.0 filelock==3.13.1 flash-attn==1.0.9 frozenlist==1.4.1 fsspec==2023.10.0 future==0.18.2 gitdb==4.0.11 GitPython==3.1.42 GPUtil==1.4.0 grpcio==1.62.0 hdfs==2.7.3 hjson==3.1.0 huggingface-hub==0.21.3 hydra-core==1.3.2 idna==2.8 intel-openmp==2024.0.2 jieba==0.42.1 Jinja2==3.1.3 jmespath==0.10.0 joblib==1.3.2 kazoo==2.9.0 kiwisolver==1.4.5 kmontitor-client==0.0.0 lake_py_lib @ http://ziying-dl.oss-cn-hangzhou-zmf.aliyuncs.com/draft/lake_py_lib-ziying-0.1.10-py3-none-any.whl#sha256=b7eb0a34be7f2324a7ad52f321270bdd10d3097ccb3495561062199afde82516 lightning-utilities==0.10.1 lit==17.0.6 lmdb==0.94 lru-dict==1.3.0 Markdown==3.5.2 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.3.4 mdl @ file:///mdl/dist/mdl-0.2-py3-none-any.whl#sha256=4fa40dd0c42380926168d50ad2bbb8060bc35c8528ac6d171798b5b65962d6a5 mdurl==0.1.2 mkl==2024.0.0 mkl-include==2024.0.0 mpmath==1.3.0 multidict==6.0.5 multiprocess==0.70.15 nebula-mos-python-sdk @ http://yum.tbsite.net/aliyun-pypi/packages/nebula-mos-python-sdk/nebula_mos_python_sdk-0.3.16-py3-none-any.whl#sha256=f4d5efaa845830db515e70a58ffcd18772b3f1685b2a67885cbb18e5b24a7a4c nebula-py-pangu-early-test @ http://yum.tbsite.net/aliyun-pypi/packages/nebula-py-pangu-early-test/nebula_py_pangu_early_test-0.0.50-py3-none-any.whl#sha256=63eb2e21b282b4d5db8f582a5cfcb149850bee4718e22d1db449965e93b4188a networkx==3.2.1 ninja==1.11.1.1 numpy==1.21.5 nvidia-ml-py3==7.352.0 omegaconf==2.3.0 opencv-python==4.5.4.60 oss2 @ https://nebula-cv.oss-cn-zhangjiakou.aliyuncs.com/docker_image/oss2/oss2-2.16.0-py3-none-any.whl#sha256=d8b1db35ca677860c4f3d253acba1e523e8b6a045d5e8d38a91694618a53f122 packaging==23.2 pandas==1.1.5 Pillow==8.4.0 protobuf==3.20.1 psutil==5.9.5 py-cpuinfo==9.0.0 py-spy==0.3.14 pyarrow==15.0.0 pyarrow-hotfix==0.6 pybind11==2.11.1 pycparser==2.21 pycryptodome==3.20.0 pydantic==1.10.9 pydicom==1.2.2 Pygments==2.17.2 pykmonitor @ http://xdl2-image-deps.oss-cn-hangzhou-zmf.aliyuncs.com/ai_scheduler%2Fwheels%2Fpykmonitor-1.0-py3-none-any.whl#sha256=5aa70ffcf6631af1ac47d1e338329c72f02c9c0407d72f9d5a246b72484074a6 pynvml==11.5.0 pyodps-int @ git+http://gitlab-sc.alibaba-inc.com/odps/pyodps.git@b7124a312fdc162762f2a7ead6e22d75ff7b5eab pyparsing==3.1.1 pytest-runner==6.0.1 python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 redis==5.0.2 regex==2023.12.25 requests==2.31.0 retrying==1.3.4 rich==13.7.1 safetensors==0.4.2 scikit-learn==1.4.1.post1 scipy==1.7.3 sentencepiece==0.1.96 sentry-sdk==1.40.6 setproctitle==1.3.3 shtab==1.7.0 simplejson==3.17.6 six==1.16.0
sklearn==0.0.post12 smmap==5.0.1 sympy==1.12 tbb==2021.11.0 tensor-parallel==2.0.0 tensorboard==2.16.2 tensorboard-data-server==0.7.2 thop-statistics @ https://nebula-cv.oss-cn-zhangjiakou.aliyuncs.com/docker_image/thop_statistics-0.1.1.post2303141613-py3-none-any.whl#sha256=1bcb3a356b8cab94d7a67b7801e70b501d4518f298cb12258762fea4e84a77c3 threadpoolctl==3.3.0 thrift==0.16.0 tiktoken==0.6.0 tokenizers==0.15.2 torch @ http://nebula-cv.oss-cn-zhangjiakou.aliyuncs.com/docker_image/pytorch/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl#sha256=bb54b705185bea820e6ec6485a25761bc03f689e1a09a37d814d6ea8e276b5bd torchaudio==2.0.2+cu117 torchmetrics==1.3.1 torchvision @ http://nebula-cv.oss-cn-zhangjiakou.aliyuncs.com/docker_image/pytorch/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl#sha256=1ee57f2bee878ad8574ea559bb7172c1cfaad168634fa738479e1fe3bdd7eaca tornado==6.1 tqdm==4.66.2 transformers==4.38.2 transformers-stream-generator==0.0.4 transitions==0.9.0 triton==2.0.0 trl==0.7.11 typing_extensions==4.10.0 tyro==0.7.3 urllib3==2.2.1 wandb==0.16.3 Werkzeug==3.0.1 xformers==0.0.21 xxhash==3.4.1 yarl==1.9.4
```

kawine commented 6 months ago

I see some differences in the dependency versions. Can you create a new conda environment from the yaml file in the repo? i.e.,

```
conda env create -f environment.yaml
```

Then try training Mistral-7B or a Llama model with the Anthropic HH dataset. Do you still get an error?

Pattaro commented 6 months ago

Due to hardware constraints, the above is the closest environment I can achieve. T.T

kawine commented 6 months ago

Ah okay. I suspect this is due to versioning conflicts between PyTorch FSDP, transformers, and flash-attention (I vaguely remember getting an error like this at one point). It's a pain getting all three of these to work harmoniously.

The good news is that a Hugging Face TRL implementation of KTO should be available soon (there is one already; it's just a little buggy). It has DeepSpeed support (allowing you to avoid FSDP) and better flash-attention integration.

I'll post that here when it's good to go.

Pattaro commented 6 months ago

I once got training to complete by adding loss.requires_grad_(True), but the policy.pt saved after training for 1 epoch is exactly the same as the original weights. What's the problem?
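That outcome is consistent with the earlier error: loss.requires_grad_(True) turns a detached loss into a grad-requiring leaf tensor, so backward() stops raising, but no gradient ever flows back to the model parameters, the optimizer has nothing to apply, and the saved weights match the originals. A minimal sketch of why (plain PyTorch):

```python
import torch

w = torch.nn.Parameter(torch.ones(3))  # stand-in for a model weight
x = torch.randn(3)

loss = (w * x).sum().detach()  # detached, like a loss with no grad_fn
loss.requires_grad_(True)      # silences the RuntimeError...
loss.backward()
print(w.grad)                  # ...but prints None: w never receives a gradient
```

The fix is to find where the graph is being severed (e.g. the checkpointing issue above), not to force requires_grad on the loss.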

xwinxu commented 6 months ago

The Hugging Face TRL implementation of KTOTrainer has now been merged. See this example script to get started. I have recently tested KTO training with DeepSpeed, launching with accelerate on multiple nodes, and things appear to work well. Note also that with the most recent transformers version, no flash attention is needed (and the updated package should run faster anyway). Let us know if you try it out!
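For anyone landing here later, a hedged sketch of what KTOTrainer usage in trl looked like around this time (argument and column names may differ between trl versions, and "gpt2" is just a toy stand-in, so defer to the linked example script):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# KTO takes per-example binary labels ("label"), not preference pairs.
train_dataset = Dataset.from_dict({
    "prompt": ["How do I sort a list?", "How do I sort a list?"],
    "completion": [" Use sorted(xs).", " Just shuffle it."],
    "label": [True, False],
})

trainer = KTOTrainer(
    model=model,
    args=KTOConfig(output_dir="kto-out", per_device_train_batch_size=2, max_steps=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```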