mrsempress closed this issue 3 months ago.
@mrsempress hey, thanks for raising the issue.
On the second problem: the root cause is probably the one identified in https://github.com/stanfordnlp/pyreft/issues/70, namely that TensorBoard is not well integrated yet. As a result, make sure to run your command with --report_to none or --report_to wandb.
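The failure in issue 2 happens when the reporting callback calls model.config.to_json_string(), which runs json.dumps over the config and hits a value that JSON cannot encode. A minimal stdlib sketch of that failure mode, assuming the config holds a Python class object inside a list (the traceback passes through _iterencode_list); the `config` dict and `DummyIntervention` class are hypothetical stand-ins, not pyreft's real config:

```python
import json


class DummyIntervention:
    """Hypothetical stand-in for a class stored in a model config."""


# Hypothetical config: a list holding a class object, not an instance or a string.
config = {"intervention_types": [DummyIntervention], "rank": 4}


def to_json_string(cfg):
    # Same call shape as configuration_utils.to_json_string in the traceback.
    return json.dumps(cfg, indent=2, sort_keys=True)


try:
    to_json_string(config)
except TypeError as err:
    # A class's own type is `type`, hence the odd "type type" wording.
    print(err)  # Object of type type is not JSON serializable


def to_json_string_safe(cfg):
    # Hypothetical workaround: encode classes by name, anything else via str().
    return json.dumps(
        cfg,
        indent=2,
        sort_keys=True,
        default=lambda o: o.__name__ if isinstance(o, type) else str(o),
    )


print(to_json_string_safe(config))
```

Passing --report_to none or --report_to wandb sidesteps the TensorBoard path instead of patching serialization, which is the simpler route for running the examples.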
@mrsempress On the first problem, could you make sure only 1 GPU is visible on your machine? (I know you ran with CUDA_VISIBLE_DEVICES=6, but this problem usually occurs when there are multiple GPUs.) Our script does not support multi-GPU training well yet.
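To double-check what "only 1 GPU visible" means for a given launch, here is a simplified stdlib sketch that counts the devices listed in CUDA_VISIBLE_DEVICES. The helper name is hypothetical; inside a running PyTorch process, torch.cuda.device_count() is the authoritative check, and the variable has to be set before CUDA is initialized to take effect.

```python
import os
from typing import Mapping, Optional


def visible_gpu_count(env: Optional[Mapping[str, str]] = None) -> Optional[int]:
    """Simplified reading of CUDA_VISIBLE_DEVICES (hypothetical helper).

    Returns None when the variable is unset, which means all GPUs are visible.
    Edge cases such as invalid or duplicate device IDs are not handled.
    """
    if env is None:
        env = os.environ
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    # An empty string hides every GPU; otherwise count the comma-separated IDs.
    return len([dev for dev in value.split(",") if dev.strip()])


print(visible_gpu_count({"CUDA_VISIBLE_DEVICES": "6"}))    # 1
print(visible_gpu_count({"CUDA_VISIBLE_DEVICES": "0,1"}))  # 2
print(visible_gpu_count({}))                               # None
```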
Sorry, I have updated the first issue: it is that bfloat16 cannot be used. For issue 2, I don't see TensorBoard used anywhere in main_demo.ipynb, and there is no argparse, so I don't understand why it happens.
@mrsempress thanks.
For issue 1), could you provide your run script?
For issue 2), could you reproduce this error by running the notebook on Google Colab and share the Colab notebook showing the error with me?
Thanks! These will help me find the root cause of the issues here.
For issue 1): CUDA_VISIBLE_DEVICES=0 python examples/loreft/train.py -task gsm8k -model ../../models/Llama-7b-hf -seed 42 -l all -r 4 -p f7+l7 -e 12 -lr 9e-4 -type NodireftIntervention -gradient_accumulation_steps 4 -batch_size 8 -eval_batch_size 4 --dropout 0.05 --test_split validation --use_normalized_template --greedy_decoding --warmup_ratio 0.00 --weight_decay 0.06
Additionally, I would like to know how much memory hyperparameter tuning and training take with LoReFT. In my hyperparameter-tuning runs, memory usage exceeds 60 GB, and I want to know whether that is caused only by changing bfloat16 to float32. Also, how long does training usually take?
For issue 2), I will update with the link once I have reproduced it. Right now, installing pyreft in Colab prompts "you must restart the runtime in order to use newly installed versions", which takes some time. I only used the original ipynb without modifying any code, so you can also try it yourself. I am not sure whether it is due to my machine's environment.
@mrsempress Thanks.
I want to know if it was only caused by changing bfloat16 to float32
Could you explain more about the change? Did you change examples/loreft/train.py? And what is the change?
For issue 2), I attached my local notebook, which does not run into this issue: main_demo.pdf
Could you check the version of your transformers library? Could you install version 4.39.3 and try again? It is most likely an environment/setup issue, since all my experiments run just fine.
@mrsempress Minor: for the memory profile, you can check our publicly released wandb logs. These are for our arithmetic benchmarks; the 7B experiments were run on a 40GB A100. I also attached the Process GPU Memory Allocated (%) chart here:
Please check the logs for other details.
I did not modify examples/loreft/train.py. Issue 1 is that I cannot use bfloat16, so I added --dtype float32 on the command line. I want to know why memory usage is so large and why bfloat16 cannot be used. My previous transformers version was 4.40.2; after changing it to 4.39.3, it still did not work for me.
Thank you for your patient reply. I now understand how much memory is actually required, but I still need to find out whether the larger memory usage when running the same command is due to not being able to use bfloat16 or to other factors.
Hey! Yes, I think so. I am running with bf16, and that is probably why my memory usage is lower.
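A rough, weights-only estimate shows why bf16 matters: each parameter takes 4 bytes in float32 and 2 bytes in bfloat16. This sketch assumes a nominal 7e9 parameters and that --dtype sets the dtype the base model is loaded in; it ignores activations (which grow with batch size), intervention parameters, optimizer state, and CUDA overhead, so real usage will be higher.

```python
# Bytes per element for the dtypes discussed in this thread.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Memory for model weights only, in GiB (2**30 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(7e9, dtype):.1f} GiB")
# float32: 26.1 GiB
# bfloat16: 13.0 GiB
```

Under these assumptions, switching from bfloat16 to float32 adds roughly 13 GiB for the frozen weights alone, before any activation memory.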
@mrsempress what is your torch version?
@mrsempress Hey, this is probably an environment issue. To resolve it, maybe create a clean conda env and install the same package versions as I have.
Here is the requirements.txt as well as the environment.yml file of my conda env:
name: wuzhengx-310
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.13=h7f8727e_0
  - python=3.10.13=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - accelerate==0.29.1
      - aiofiles==23.2.1
      - aiohttp==3.9.3
      - aiosignal==1.3.1
      - alpaca-eval==0.6
      - altair==5.2.0
      - annotated-types==0.6.0
      - anyio==4.3.0
      - appdirs==1.4.4
      - argon2-cffi==23.1.0
      - argon2-cffi-bindings==21.2.0
      - arrow==1.3.0
      - asttokens==2.4.1
      - async-lru==2.0.4
      - async-timeout==4.0.3
      - attrs==23.2.0
      - babel==2.14.0
      - beautifulsoup4==4.12.3
      - bitsandbytes==0.42.0
      - bleach==6.1.0
      - cachetools==5.3.3
      - certifi==2024.2.2
      - cffi==1.16.0
      - charset-normalizer==3.3.2
      - click==8.1.7
      - colorama==0.4.6
      - comm==0.2.1
      - contourpy==1.2.0
      - cycler==0.12.1
      - dacite==1.8.1
      - datasets==2.18.0
      - debugpy==1.8.1
      - decorator==5.1.1
      - defusedxml==0.7.1
      - diffusers==0.27.2
      - dill==0.3.7
      - distro==1.9.0
      - docker-pycreds==0.4.0
      - einops==0.7.0
      - evaluate==0.4.1
      - exceptiongroup==1.2.0
      - executing==2.0.1
      - fastapi==0.110.0
      - fastjsonschema==2.19.1
      - ffmpy==0.3.2
      - filelock==3.13.1
      - fire==0.5.0
      - fonttools==4.49.0
      - fqdn==1.5.1
      - frozenlist==1.4.1
      - fsspec==2024.2.0
      - gcsfs==2024.2.0
      - gitdb==4.0.11
      - gitpython==3.1.42
      - google-api-core==2.18.0
      - google-auth==2.29.0
      - google-auth-oauthlib==1.2.0
      - google-cloud-core==2.4.1
      - google-cloud-storage==2.16.0
      - google-crc32c==1.5.0
      - google-resumable-media==2.7.0
      - googleapis-common-protos==1.63.0
      - gradio==3.50.0
      - gradio-client==0.6.1
      - h11==0.14.0
      - htmlmin==0.1.12
      - httpcore==1.0.4
      - httpx==0.27.0
      - huggingface-hub==0.20.3
      - idna==3.6
      - imagehash==4.3.1
      - importlib-metadata==7.1.0
      - importlib-resources==6.1.2
      - ipykernel==6.29.3
      - ipython==8.22.1
      - ipywidgets==8.1.1
      - isoduration==20.11.0
      - jedi==0.19.1
      - jinja2==3.1.3
      - joblib==1.3.2
      - json5==0.9.17
      - jsonpointer==2.4
      - jsonschema==4.21.1
      - jsonschema-specifications==2023.12.1
      - jupyter==1.0.0
      - jupyter-client==8.6.0
      - jupyter-console==6.6.3
      - jupyter-core==5.7.1
      - jupyter-events==0.9.0
      - jupyter-lsp==2.2.3
      - jupyter-server==2.12.5
      - jupyter-server-terminals==0.5.2
      - jupyterlab==4.1.2
      - jupyterlab-pygments==0.3.0
      - jupyterlab-server==2.25.3
      - jupyterlab-widgets==3.0.10
      - kiwisolver==1.4.5
      - llvmlite==0.42.0
      - markdown-it-py==3.0.0
      - markupsafe==2.1.5
      - matplotlib==3.7.4
      - matplotlib-inline==0.1.6
      - mdurl==0.1.2
      - mistune==3.0.2
      - mizani==0.9.3
      - mpmath==1.3.0
      - multidict==6.0.5
      - multimethod==1.11.2
      - multiprocess==0.70.15
      - nbclient==0.9.0
      - nbconvert==7.16.1
      - nbformat==5.9.2
      - nest-asyncio==1.6.0
      - networkx==3.2.1
      - ninja==1.11.1.1
      - notebook==7.1.1
      - notebook-shim==0.2.4
      - numba==0.59.1
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.19.3
      - nvidia-nvjitlink-cu12==12.3.101
      - nvidia-nvtx-cu12==12.1.105
      - oauthlib==3.2.2
      - openai==1.12.0
      - orjson==3.9.15
      - overrides==7.7.0
      - packaging==23.2
      - pandas==2.2.1
      - pandocfilters==1.5.1
      - parso==0.8.3
      - patsy==0.5.6
      - peft==0.11.1
      - pexpect==4.9.0
      - phik==0.12.4
      - pillow==10.2.0
      - pip==23.3.1
      - platformdirs==4.2.0
      - plotnine==0.12.4
      - prometheus-client==0.20.0
      - prompt-toolkit==3.0.43
      - proto-plus==1.23.0
      - protobuf==3.20.3
      - psutil==5.9.8
      - ptyprocess==0.7.0
      - pure-eval==0.2.2
      - pyarrow==15.0.0
      - pyarrow-hotfix==0.6
      - pyasn1==0.6.0
      - pyasn1-modules==0.4.0
      - pycparser==2.21
      - pydantic==2.6.2
      - pydantic-core==2.16.3
      - pydub==0.25.1
      - pygments==2.17.2
      - pyparsing==3.1.1
      - pyreft==0.0.4
      - python-dateutil==2.8.2
      - python-dotenv==1.0.1
      - python-json-logger==2.0.7
      - python-multipart==0.0.9
      - pytz==2024.1
      - pyvene==0.1.2
      - pywavelets==1.6.0
      - pyyaml==6.0.1
      - pyzmq==25.1.2
      - qtconsole==5.5.1
      - qtpy==2.4.1
      - referencing==0.33.0
      - reft==0.0.1.dev0
      - regex==2023.12.25
      - requests==2.31.0
      - requests-oauthlib==2.0.0
      - responses==0.18.0
      - rfc3339-validator==0.1.4
      - rfc3986-validator==0.1.1
      - rich==13.7.1
      - rpds-py==0.18.0
      - rsa==4.9
      - ruff==0.3.0
      - safetensors==0.4.2
      - scikit-learn==1.4.1.post1
      - scipy==1.11.4
      - seaborn==0.12.2
      - semantic-version==2.10.0
      - send2trash==1.8.2
      - sentencepiece==0.1.96
      - sentry-sdk==1.40.6
      - setproctitle==1.3.3
      - setuptools==68.2.2
      - shellingham==1.5.4
      - six==1.16.0
      - smmap==5.0.1
      - sniffio==1.3.1
      - soupsieve==2.5
      - spaces==0.26.0
      - stack-data==0.6.3
      - starlette==0.36.3
      - statsmodels==0.14.1
      - sympy==1.12
      - termcolor==2.4.0
      - terminado==0.18.0
      - threadpoolctl==3.3.0
      - tiktoken==0.6.0
      - tinycss2==1.2.1
      - tokenizers==0.15.2
      - tomli==2.0.1
      - tomlkit==0.12.0
      - toolz==0.12.1
      - torch==2.2.1
      - tornado==6.4
      - tqdm==4.66.2
      - traitlets==5.14.1
      - transformers==4.39.3
      - triton==2.2.0
      - typeguard==4.2.1
      - typer==0.9.0
      - types-python-dateutil==2.8.19.20240106
      - typing-extensions==4.10.0
      - tzdata==2024.1
      - uri-template==1.3.0
      - urllib3==2.2.1
      - uvicorn==0.27.1
      - visions==0.7.6
      - wandb==0.16.3
      - wcwidth==0.2.13
      - webcolors==1.13
      - webencodings==0.5.1
      - websocket-client==1.7.0
      - websockets==11.0.3
      - wheel==0.41.2
      - widgetsnbextension==4.0.10
      - wordcloud==1.9.3
      - xxhash==3.4.1
      - yarl==1.9.4
      - ydata-profiling==4.7.0
      - zipp==3.18.1
Please let me know if the problem still exists. Thanks.
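Before rebuilding the environment, it can help to diff the current one against the key pins in the environment.yml above. A minimal stdlib sketch using importlib.metadata; which packages go into `PINS` is an assumption (the ones most relevant to this thread), with versions copied from that file.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Subset of pins from the maintainer's environment.yml (the selection is an assumption).
PINS = {
    "torch": "2.2.1",
    "transformers": "4.39.3",
    "pyvene": "0.1.2",
    "pyreft": "0.0.4",
    "accelerate": "0.29.1",
}


def pin_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every pin that does not match.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


for name, (expected, installed) in pin_mismatches(PINS).items():
    print(f"{name}: expected {expected}, found {installed}")
```

This is a plain string comparison, so a CUDA build such as 2.2.1+cu121 is reported as a mismatch even though the release number agrees.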
OK, so bf16 is probably why your memory usage is lower.
My torch version is 2.0.1.
Ok, I will try.
After I updated the versions so that bfloat16 works, the 7B experiments on the arithmetic tasks need 52574 MiB with a batch size of 8, but your experiments run on a 40GB A100. So, besides bfloat16, other factors must be reducing memory. The command I run is the one in examples/loreft/README.md:
python train.py -task math \
-data_dir dataset \
-model yahma/llama-7b-hf \
-seed 42 \
-l all -r 8 -p f7+l7 -e 12 -lr 9e-4 \
-type LoreftIntervention \
-gradient_accumulation_steps 2 \
-batch_size 16 \
-eval_batch_size 4 \
--dropout 0.00 \
--test_split test \
--use_normalized_template \
--share_weights \
--warmup_ratio 0.1 \
--greedy_decoding \
--save_model
@mrsempress This is expected, I think: I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, the memory usage could roughly double?
Here is a screenshot of the stats from one of the publicly released runs: https://wandb.ai/wuzhengx/ReFT_MuadDib_math/runs/xoumltuz/
So, should I use -gradient_accumulation_steps 16 -batch_size 8 or -gradient_accumulation_steps 4 -batch_size 8?
I also want to know how to set hyperparameters like gradient_accumulation_steps, since we cannot directly follow the command in examples/loreft/README.md.
Sorry about the confusion, but I think nothing actually changes here: for the hyperparameters, what matters is the effective batch size, which is the per-device batch size (bounded by GPU memory) times the number of gradient accumulation steps (which you can set freely to reach the target effective batch size). The script you sent earlier and my settings both have an effective batch size of 32.
Note that in the paper we only report the effective batch size, not the per-device batch size.
Hope this helps. Thanks.
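The effective-batch-size arithmetic can be sketched as follows; the helper names are hypothetical, and the num_devices factor only matters for multi-GPU runs, which the script does not support well yet.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    # Number of samples that contribute to each optimizer step.
    return per_device_batch * grad_accum_steps * num_devices


def grad_accum_for(target: int, per_device_batch: int, num_devices: int = 1) -> int:
    # Accumulation steps needed to reach `target`; it must divide evenly.
    per_step = per_device_batch * num_devices
    if target % per_step:
        raise ValueError(f"{target} is not divisible by {per_step}")
    return target // per_step


print(effective_batch_size(16, 2))  # README command: 32
print(effective_batch_size(4, 8))   # maintainer's setting: 32
print(grad_accum_for(32, 8))        # per-device batch of 8 needs 4 steps
```

With -batch_size 8, then, -gradient_accumulation_steps 4 keeps the effective batch size at 32, whereas 16 would raise it to 128.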
Thank you for your reply. Now I understand how to set gradient_accumulation_steps and batch size.
Thanks for your wonderful model, but I have run into some problems.
- I cannot use bfloat16:
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1211, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 992, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1095, in _update_causal_mask
    causal_mask = torch.triu(causal_mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
- I ran main_demo.ipynb but got this error:
Traceback (most recent call last):
  File "/mnt/geogpt-gpfs/pyreft/inference.py", line 61, in <module>
    _ = trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2116, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 371, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 415, in call_event
    result = getattr(callback, event)(
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 636, in on_train_begin
    model_config_json = model.config.to_json_string()
  File "/opt/conda/lib/python3.10/site-packages/transformers/configuration_utils.py", line 938, in to_json_string
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  File "/opt/conda/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/opt/conda/lib/python3.10/json/encoder.py", line 325, in _iterencode_list
    yield from chunks
  File "/opt/conda/lib/python3.10/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/opt/conda/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type type is not JSON serializable
I found issue 69, but I am using main_demo.ipynb, so its solution does not work for me.
For problem 1, I upgraded to torch 2.3.1, and it works.
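Since upgrading torch fixed the bfloat16 triu error here (2.0.1 failed, while the maintainer's 2.2.1 and this 2.3.1 worked), a startup guard could compare version strings numerically. Treating 2.2.1 as the minimum is an assumption based only on those reports, not a documented threshold, and the parser only handles plain X.Y.Z versions with an optional local suffix such as +cu121.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    # Drop a local suffix such as "+cu121", then split into numeric components.
    release = version.split("+", 1)[0]
    return tuple(int(part) for part in release.split("."))


def at_least(installed: str, minimum: str) -> bool:
    # Tuple comparison orders 2.10.0 after 2.2.1, unlike plain string comparison.
    return parse_version(installed) >= parse_version(minimum)


# Assumed minimum: the torch version the maintainer ran with bf16.
MIN_TORCH_FOR_BF16 = "2.2.1"

for v in ("2.0.1", "2.2.1+cu121", "2.3.1"):
    print(v, at_least(v, MIN_TORCH_FOR_BF16))
```

In a real script the installed string would come from torch.__version__; pre-release tags such as 2.3.0a0 would need a fuller parser.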