stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P1] [Error] can not use bfloat16 and TypeError: Object of type type is not JSON serializable #102

Closed mrsempress closed 3 months ago

mrsempress commented 3 months ago

Thanks for your wonderful model, but I have run into some problems.

  1. bfloat16 cannot be used (a minimal reproduction is sketched after this list):

    File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1211, in forward
    outputs = self.model(
    File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 992, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
    File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1095, in _update_causal_mask
    causal_mask = torch.triu(causal_mask, diagonal=1)
    RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
  2. Running the main_demo.ipynb, I got this error:

    Traceback (most recent call last):                                                                                                                                                
    File "/mnt/geogpt-gpfs/pyreft/inference.py", line 61, in <module>                                                                  
    _ = trainer.train()                                                                                                                                                           
    File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train                                                                                     
    return inner_training_loop(                                                                                                                                                   
    File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2116, in _inner_training_loop                                                                      
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)                                                                                           
    File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 371, in on_train_begin                                                                    
    return self.call_event("on_train_begin", args, state, control)                                                                                                                
    File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 415, in call_event                                                                        
    result = getattr(callback, event)(                                                                                                                                            
    File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 636, in on_train_begin                                                      
    model_config_json = model.config.to_json_string()                                                                                                                             
    File "/opt/conda/lib/python3.10/site-packages/transformers/configuration_utils.py", line 938, in to_json_string                                                                 
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
    File "/opt/conda/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
    File "/opt/conda/lib/python3.10/json/encoder.py", line 201, in encode
    chunks = list(chunks)
    File "/opt/conda/lib/python3.10/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
    File "/opt/conda/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
    File "/opt/conda/lib/python3.10/json/encoder.py", line 325, in _iterencode_list
    yield from chunks
    File "/opt/conda/lib/python3.10/json/encoder.py", line 438, in _iterencode
    o = _default(o)
    File "/opt/conda/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
    TypeError: Object of type type is not JSON serializable

    I found issue #69, but since I am using main_demo.ipynb, its fix does not work for me.
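For reference, the first failure can be reproduced outside the trainer in a few lines (a minimal sketch, assuming a CUDA device and an older torch build such as 2.0.x):

    import torch

    # Building a causal mask in bfloat16, as modeling_llama.py does internally:
    mask = torch.full((8, 8), torch.finfo(torch.bfloat16).min, dtype=torch.bfloat16, device="cuda")
    mask = torch.triu(mask, diagonal=1)  # RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'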

frankaging commented 3 months ago

@mrsempress hey, thanks for raising the issue.

on the second problem: the root cause is probably the one identified in https://github.com/stanfordnlp/pyreft/issues/70 -- tensorboard is not well integrated yet.

as a result, you need to make sure to run your command with --report_to none, or --report_to wandb.
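A sketch of what that looks like, reusing flags from the training command quoted later in this thread (the task and model path are illustrative):

    CUDA_VISIBLE_DEVICES=0 python examples/loreft/train.py -task gsm8k -model yahma/llama-7b-hf --report_to none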

frankaging commented 3 months ago

@mrsempress on the first problem, could you make sure there is only 1 GPU visible on your machine? (I know you ran with CUDA_VISIBLE_DEVICES=6, but this problem usually occurs when there are multiple GPUs.) Our script does not yet support multi-GPU training well.
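A quick, pyreft-agnostic way to check what the process actually sees (a minimal sketch):

    import torch

    # With CUDA_VISIBLE_DEVICES set to a single GPU this should print 1;
    # anything larger means multiple GPUs are visible to the process.
    print(torch.cuda.device_count())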

mrsempress commented 3 months ago

Sorry, the first issue has been updated; it is that bfloat16 cannot be used. For issue 2, I did not see tensorboard appear anywhere in main_demo.ipynb, and there is no argparse, so I do not understand why this error occurs.

frankaging commented 3 months ago

@mrsempress thanks.

For issue 1) could you provide your running script?

For issue 2) could you reproduce this error by running the notebook on Google Colab, and share the failing Colab with me?

Thanks! These will help me root-cause the issues here.

mrsempress commented 3 months ago

For issue 1), this is my command:

    CUDA_VISIBLE_DEVICES=0 python examples/loreft/train.py -task gsm8k -model ../../models/Llama-7b-hf -seed 42 -l all -r 4 -p f7+l7 -e 12 -lr 9e-4 -type NodireftIntervention -gradient_accumulation_steps 4 -batch_size 8 -eval_batch_size 4 --dropout 0.05 --test_split validation --use_normalized_template --greedy_decoding --warmup_ratio 0.00 --weight_decay 0.06

Additionally, I would like to know how much memory hyperparameter tuning and training occupy in LoReFT: with my hyperparameter tuning it takes over 60 GB of memory. I want to know whether that is caused solely by changing bfloat16 to float32. Also, how long does training usually take?

For issue 2), once I reproduce it successfully I will update the link. Currently, when installing pyreft, Colab prompts "you must restart the runtime in order to use newly installed versions", which takes some time. I only used the original ipynb without modifying the code, so you can also try the experiment. I am not sure whether it is a machine environment issue.

frankaging commented 3 months ago

@mrsempress Thanks.

> I want to know whether that is caused solely by changing bfloat16 to float32

Could you explain the change in more detail? Did you change examples/loreft/train.py, and if so, what was the change?

For issue 2), I attached my local notebook, which does not encounter this issue: main_demo.pdf

Could you check the version of your transformers library? Could you install version 4.39.3 and try again? It is most likely an env/setup issue, since all my experiments are running just fine.
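i.e., pinning the version with pip:

    pip install transformers==4.39.3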

frankaging commented 3 months ago

@mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. These are for our arithmetic benchmarks; the 7B experiments are run on a 40G A100. I also attached Process GPU Memory Allocated (%) here:

Screenshot 2024-06-06 at 3 01 01 PM

Please go to the logs, and trace out other details.

mrsempress commented 3 months ago

> Could you explain the change in more detail? Did you change examples/loreft/train.py, and if so, what was the change?
>
> Could you check the version of your transformers library? Could you install version 4.39.3 and try again?

I did not modify examples/loreft/train.py. Issue 1 means that I cannot use bfloat16, so I added --dtype float32 on the command line. I want to know why the memory usage is so large and why bfloat16 cannot be used. My previous transformers version was 4.40.2; after I changed it to 4.39.3, it still did not work for me.

mrsempress commented 3 months ago

> @mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. These are for our arithmetic benchmarks; the 7B experiments are run on a 40G A100. I also attached Process GPU Memory Allocated (%) here:
>
> Screenshot 2024-06-06 at 3 01 01 PM
>
> Please go to the logs, and trace out other details.

Thank you for your patient reply. I now understand the actual amount of memory required, but I need to find out whether it is because bfloat16 cannot be used, or whether other factors cause the memory to be too large when running the same command.

frankaging commented 3 months ago

> I need to find out whether it is because bfloat16 cannot be used, or whether ..

Hey! Yes, I think so. I am running with bf16, and that is probably why my memory usage is lower.
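As a rough sanity check (back-of-the-envelope arithmetic, not a measured profile), the dtype of the frozen 7B base weights alone accounts for a double-digit GiB difference:

    # Approximate memory held by the frozen 7B base weights alone
    # (activations, gradients, and optimizer state come on top of this):
    params = 7e9
    print(f"float32:  {params * 4 / 2**30:.1f} GiB")  # ~26.1 GiB
    print(f"bfloat16: {params * 2 / 2**30:.1f} GiB")  # ~13.0 GiB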

frankaging commented 3 months ago

@mrsempress what is your torch version?

frankaging commented 3 months ago

@mrsempress hey, this is probably an env issue. To resolve it, maybe create a clean conda env and install packages in the same versions as I have.

here is the requirements.txt as well as the environment.yml file of my conda env:

requirements.txt

name: wuzhengx-310
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.13=h7f8727e_0
  - python=3.10.13=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - accelerate==0.29.1
    - aiofiles==23.2.1
    - aiohttp==3.9.3
    - aiosignal==1.3.1
    - alpaca-eval==0.6
    - altair==5.2.0
    - annotated-types==0.6.0
    - anyio==4.3.0
    - appdirs==1.4.4
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.3.0
    - asttokens==2.4.1
    - async-lru==2.0.4
    - async-timeout==4.0.3
    - attrs==23.2.0
    - babel==2.14.0
    - beautifulsoup4==4.12.3
    - bitsandbytes==0.42.0
    - bleach==6.1.0
    - cachetools==5.3.3
    - certifi==2024.2.2
    - cffi==1.16.0
    - charset-normalizer==3.3.2
    - click==8.1.7
    - colorama==0.4.6
    - comm==0.2.1
    - contourpy==1.2.0
    - cycler==0.12.1
    - dacite==1.8.1
    - datasets==2.18.0
    - debugpy==1.8.1
    - decorator==5.1.1
    - defusedxml==0.7.1
    - diffusers==0.27.2
    - dill==0.3.7
    - distro==1.9.0
    - docker-pycreds==0.4.0
    - einops==0.7.0
    - evaluate==0.4.1
    - exceptiongroup==1.2.0
    - executing==2.0.1
    - fastapi==0.110.0
    - fastjsonschema==2.19.1
    - ffmpy==0.3.2
    - filelock==3.13.1
    - fire==0.5.0
    - fonttools==4.49.0
    - fqdn==1.5.1
    - frozenlist==1.4.1
    - fsspec==2024.2.0
    - gcsfs==2024.2.0
    - gitdb==4.0.11
    - gitpython==3.1.42
    - google-api-core==2.18.0
    - google-auth==2.29.0
    - google-auth-oauthlib==1.2.0
    - google-cloud-core==2.4.1
    - google-cloud-storage==2.16.0
    - google-crc32c==1.5.0
    - google-resumable-media==2.7.0
    - googleapis-common-protos==1.63.0
    - gradio==3.50.0
    - gradio-client==0.6.1
    - h11==0.14.0
    - htmlmin==0.1.12
    - httpcore==1.0.4
    - httpx==0.27.0
    - huggingface-hub==0.20.3
    - idna==3.6
    - imagehash==4.3.1
    - importlib-metadata==7.1.0
    - importlib-resources==6.1.2
    - ipykernel==6.29.3
    - ipython==8.22.1
    - ipywidgets==8.1.1
    - isoduration==20.11.0
    - jedi==0.19.1
    - jinja2==3.1.3
    - joblib==1.3.2
    - json5==0.9.17
    - jsonpointer==2.4
    - jsonschema==4.21.1
    - jsonschema-specifications==2023.12.1
    - jupyter==1.0.0
    - jupyter-client==8.6.0
    - jupyter-console==6.6.3
    - jupyter-core==5.7.1
    - jupyter-events==0.9.0
    - jupyter-lsp==2.2.3
    - jupyter-server==2.12.5
    - jupyter-server-terminals==0.5.2
    - jupyterlab==4.1.2
    - jupyterlab-pygments==0.3.0
    - jupyterlab-server==2.25.3
    - jupyterlab-widgets==3.0.10
    - kiwisolver==1.4.5
    - llvmlite==0.42.0
    - markdown-it-py==3.0.0
    - markupsafe==2.1.5
    - matplotlib==3.7.4
    - matplotlib-inline==0.1.6
    - mdurl==0.1.2
    - mistune==3.0.2
    - mizani==0.9.3
    - mpmath==1.3.0
    - multidict==6.0.5
    - multimethod==1.11.2
    - multiprocess==0.70.15
    - nbclient==0.9.0
    - nbconvert==7.16.1
    - nbformat==5.9.2
    - nest-asyncio==1.6.0
    - networkx==3.2.1
    - ninja==1.11.1.1
    - notebook==7.1.1
    - notebook-shim==0.2.4
    - numba==0.59.1
    - numpy==1.26.4
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==8.9.2.26
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-nccl-cu12==2.19.3
    - nvidia-nvjitlink-cu12==12.3.101
    - nvidia-nvtx-cu12==12.1.105
    - oauthlib==3.2.2
    - openai==1.12.0
    - orjson==3.9.15
    - overrides==7.7.0
    - packaging==23.2
    - pandas==2.2.1
    - pandocfilters==1.5.1
    - parso==0.8.3
    - patsy==0.5.6
    - peft==0.11.1
    - pexpect==4.9.0
    - phik==0.12.4
    - pillow==10.2.0
    - pip==23.3.1
    - platformdirs==4.2.0
    - plotnine==0.12.4
    - prometheus-client==0.20.0
    - prompt-toolkit==3.0.43
    - proto-plus==1.23.0
    - protobuf==3.20.3
    - psutil==5.9.8
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyarrow==15.0.0
    - pyarrow-hotfix==0.6
    - pyasn1==0.6.0
    - pyasn1-modules==0.4.0
    - pycparser==2.21
    - pydantic==2.6.2
    - pydantic-core==2.16.3
    - pydub==0.25.1
    - pygments==2.17.2
    - pyparsing==3.1.1
    - pyreft==0.0.4
    - python-dateutil==2.8.2
    - python-dotenv==1.0.1
    - python-json-logger==2.0.7
    - python-multipart==0.0.9
    - pytz==2024.1
    - pyvene==0.1.2
    - pywavelets==1.6.0
    - pyyaml==6.0.1
    - pyzmq==25.1.2
    - qtconsole==5.5.1
    - qtpy==2.4.1
    - referencing==0.33.0
    - reft==0.0.1.dev0
    - regex==2023.12.25
    - requests==2.31.0
    - requests-oauthlib==2.0.0
    - responses==0.18.0
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rich==13.7.1
    - rpds-py==0.18.0
    - rsa==4.9
    - ruff==0.3.0
    - safetensors==0.4.2
    - scikit-learn==1.4.1.post1
    - scipy==1.11.4
    - seaborn==0.12.2
    - semantic-version==2.10.0
    - send2trash==1.8.2
    - sentencepiece==0.1.96
    - sentry-sdk==1.40.6
    - setproctitle==1.3.3
    - setuptools==68.2.2
    - shellingham==1.5.4
    - six==1.16.0
    - smmap==5.0.1
    - sniffio==1.3.1
    - soupsieve==2.5
    - spaces==0.26.0
    - stack-data==0.6.3
    - starlette==0.36.3
    - statsmodels==0.14.1
    - sympy==1.12
    - termcolor==2.4.0
    - terminado==0.18.0
    - threadpoolctl==3.3.0
    - tiktoken==0.6.0
    - tinycss2==1.2.1
    - tokenizers==0.15.2
    - tomli==2.0.1
    - tomlkit==0.12.0
    - toolz==0.12.1
    - torch==2.2.1
    - tornado==6.4
    - tqdm==4.66.2
    - traitlets==5.14.1
    - transformers==4.39.3
    - triton==2.2.0
    - typeguard==4.2.1
    - typer==0.9.0
    - types-python-dateutil==2.8.19.20240106
    - typing-extensions==4.10.0
    - tzdata==2024.1
    - uri-template==1.3.0
    - urllib3==2.2.1
    - uvicorn==0.27.1
    - visions==0.7.6
    - wandb==0.16.3
    - wcwidth==0.2.13
    - webcolors==1.13
    - webencodings==0.5.1
    - websocket-client==1.7.0
    - websockets==11.0.3
    - wheel==0.41.2
    - widgetsnbextension==4.0.10
    - wordcloud==1.9.3
    - xxhash==3.4.1
    - yarl==1.9.4
    - ydata-profiling==4.7.0
    - zipp==3.18.1

please let me know if the problem still exists. thanks.

mrsempress commented 3 months ago

> I need to find out whether it is because bfloat16 cannot be used, or whether ..
>
> Hey! Yes, I think so. I am running with bf16, and that is probably why my memory usage is lower.

Ok~

mrsempress commented 3 months ago

> @mrsempress what is your torch version?

My torch version is 2.0.1.

mrsempress commented 3 months ago

> @mrsempress hey, this is probably an env issue. To resolve it, maybe create a clean conda env and install packages in the same versions as I have.
>
> here is the requirements.txt as well as the environment.yml file of my conda env (quoted in full above)
>
> please let me know if the problem still exists. thanks.

Ok, I will try.

mrsempress commented 3 months ago

> @mrsempress minor: in terms of memory profile, you could check our publicly released logs on wandb. These are for our arithmetic benchmarks; the 7B experiments are run on a 40G A100. I also attached Process GPU Memory Allocated (%) here:
>
> Screenshot 2024-06-06 at 3 01 01 PM
>
> Please go to the logs, and trace out other details.

After I updated the version to make bfloat16 available, the 7B experiments for the arithmetic tasks need 52574 MiB when the batch size is 8, but your experiment runs on a 40G A100. That is to say, besides bfloat16, there are other factors that reduce memory. The command I ran is the one from examples/loreft/README.md:

    python train.py -task math -data_dir dataset -model yahma/llama-7b-hf -seed 42 -l all -r 8 -p f7+l7 -e 12 -lr 9e-4 -type LoreftIntervention -gradient_accumulation_steps 2 -batch_size 16 -eval_batch_size 4 --dropout 0.00 --test_split test --use_normalized_template --share_weights --warmup_ratio 0.1 --greedy_decoding --save_model

frankaging commented 3 months ago

@mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, the memory usage could roughly double?

frankaging commented 3 months ago

this is the screenshot of one of the publicly released run stats: https://wandb.ai/wuzhengx/ReFT_MuadDib_math/runs/xoumltuz/

Screenshot 2024-06-12 at 1 43 33 PM

mrsempress commented 3 months ago

> @mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, the memory usage could roughly double?

So, should I use -gradient_accumulation_steps 16 -batch_size 8 or -gradient_accumulation_steps 4 -batch_size 8? I also want to know how to set hyperparameters like gradient_accumulation_steps, since we cannot directly follow the command in examples/loreft/README.md.

frankaging commented 3 months ago

> @mrsempress this is expected, I think; I am using -gradient_accumulation_steps 8 -batch_size 4, so if you are running with a batch size of 8, the memory usage could roughly double?
>
> So, should I use -gradient_accumulation_steps 16 -batch_size 8 or -gradient_accumulation_steps 4 -batch_size 8?
>
> I also want to know how to set hyperparameters like gradient_accumulation_steps, since we cannot directly follow the command in examples/loreft/README.md.

Sorry about the confusion - but I think nothing needs to change here: for hyperparameters, what matters is the effective batch size, which is the per-device batch size (bounded by GPU memory) times the gradient accumulation steps (you can set whatever combination matches the effective batch size). For the script you sent earlier and my settings, both have an effective batch size of 32.

Note that in the paper we only report the effective batch size, not the per-device batch size.
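A small worked check of that equivalence (a sketch; the numbers are just the settings mentioned in this thread):

    # effective batch size = per-device batch size * gradient accumulation steps
    assert 4 * 8 == 16 * 2 == 8 * 4 == 32  # all of these settings train with an effective batch size of 32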

Hope these help. Thanks.

mrsempress commented 3 months ago

> gradient_accumulation_steps

Thank you for your reply. Now I understand how to set gradient_accumulation_steps and the batch size.

johnson7788 commented 2 months ago

> Thanks for your wonderful model, but I have run into some problems.
>
> 1. bfloat16 cannot be used: RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
>
> 2. Running the main_demo.ipynb fails with: TypeError: Object of type type is not JSON serializable
>
> (full tracebacks are quoted in the opening comment above)

For problem 1, upgrading to torch 2.3.1 fixed it for me.
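For example (pinning the version with pip):

    pip install torch==2.3.1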