yaojin17 / Unlearning_LLM

[ACL 2024] Code and data for "Machine Unlearning of Pre-trained Large Language Models"
MIT License

The new version of PyTorch needs to specify the tensor location #7

Open · qzc438 opened 1 month ago

qzc438 commented 1 month ago

As the title describes, please also see the attached error message:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices (when checking argument for argument index in method wrapper_CUDA__index_select)
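For reference, this error usually means an index tensor (typically `input_ids`) is still on the CPU while the model's embedding weights are on the GPU; the embedding lookup is where an `index_select` like the one in the message typically comes from. Below is a minimal sketch of the usual workaround, assuming a single-device model and a dict-style batch. The helpers `move_to_device` and `model_device` are hypothetical names, not part of this repository, and this is not necessarily how the repository itself handles it:

```python
from typing import Any

import torch
import torch.nn as nn


def move_to_device(batch: Any, device: torch.device) -> Any:
    """Recursively move every tensor in a (possibly nested) batch to `device`.

    Non-tensor values (ints, strings, None, ...) are returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, list):
        return [move_to_device(v, device) for v in batch]
    if isinstance(batch, tuple):
        return tuple(move_to_device(v, device) for v in batch)
    return batch


def model_device(model: nn.Module) -> torch.device:
    # Assumption: all parameters live on a single device (no multi-GPU sharding).
    return next(model.parameters()).device


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Toy stand-in for an LLM's token-embedding layer.
    embedding = nn.Embedding(num_embeddings=100, embedding_dim=8).to(device)
    # Batches built on the CPU, as a data loader usually produces them.
    batch = {"input_ids": torch.tensor([[1, 2, 3]]), "labels": torch.tensor([[2, 3, 4]])}
    # Without this move, a CUDA embedding fed CPU indices raises the device-mismatch error.
    batch = move_to_device(batch, model_device(embedding))
    out = embedding(batch["input_ids"])
    print(out.shape, out.device)
```

This assumes every parameter sits on one device; a model split across several GPUs would need a different rule for choosing the target device.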

yaojin17 commented 3 weeks ago

Apologies for the late response. Has the issue been resolved? I'm using torch==2.1.2+cu118 on my end.
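When comparing setups, it may help for both sides to report the exact PyTorch build, since the `+cu118` suffix and the visible GPUs matter as much as the version number. A small sketch using standard `torch` attributes; the function name `describe_torch_build` is illustrative only:

```python
import torch


def describe_torch_build() -> dict:
    """Collect the PyTorch build details that usually matter when comparing environments."""
    return {
        "torch": torch.__version__,  # e.g. a "+cu118" suffix marks a CUDA 11.8 wheel
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }


if __name__ == "__main__":
    print(describe_torch_build())
```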

qzc438 commented 3 weeks ago

Sorry, I still cannot resolve this issue.

yaojin17 commented 3 weeks ago
name: base
channels:
  - conda-forge
  - https://repo.anaconda.com/pkgs/main
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - brotlipy=0.7.0=py310h7f8727e_1002
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2024.2.2=hbcca054_0
  - certifi=2024.2.2=pyhd8ed1ab_0
  - cffi=1.15.1=py310h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - comm=0.2.1=pyhd8ed1ab_0
  - conda=23.1.0=py310h06a4308_0
  - conda-content-trust=0.1.3=py310h06a4308_0
  - conda-package-handling=2.0.2=py310h06a4308_0
  - conda-package-streaming=0.7.0=py310h06a4308_0
  - cryptography=38.0.4=py310h9ce1e76_0
  - debugpy=1.6.7=py310h6a678d5_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - executing=2.0.1=pyhd8ed1ab_0
  - idna=3.4=py310h06a4308_0
  - ipykernel=6.29.0=pyhd33586a_0
  - ipython=8.21.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=7.3.4=pyhd8ed1ab_0
  - jupyter_core=5.7.1=py310hff52083_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.2=h6a678d5_6
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libsodium=1.0.18=h36c2ea0_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=1.1.1w=h7f8727e_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - platformdirs=4.2.0=pyhd8ed1ab_0
  - pluggy=1.0.0=py310h06a4308_1
  - prompt-toolkit=3.0.42=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycosat=0.6.4=py310h5eee18b_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.9=h7a1cb2a_0
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=2_cp310
  - pyzmq=25.1.2=py310h6a678d5_0
  - readline=8.2=h5eee18b_0
  - requests=2.28.1=py310h06a4308_0
  - ruamel.yaml=0.17.21=py310h5eee18b_0
  - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
  - setuptools=65.6.3=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.40.1=h5082296_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.12=h1ccaba5_0
  - toolz=0.12.0=py310h06a4308_0
  - tornado=6.1=py310h5764c6d_3
  - tqdm=4.64.1=py310h06a4308_0
  - traitlets=5.14.1=pyhd8ed1ab_0
  - typing_extensions=4.9.0=pyha770c72_0
  - urllib3=1.26.14=py310h06a4308_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.10=h5eee18b_1
  - zeromq=4.3.5=h6a678d5_0
  - zlib=1.2.13=h5eee18b_0
  - zstandard=0.18.0=py310h5eee18b_0
  - pip:
      - absl-py==1.4.0
      - accelerate==0.26.1
      - aiofiles==23.2.1
      - aiohttp==3.9.1
      - aioprometheus==23.12.0
      - aiosignal==1.3.1
      - altair==5.2.0
      - annotated-types==0.6.0
      - antlr4-python3-runtime==4.9.3
      - anyio==4.2.0
      - apex==0.1
      - appdirs==1.4.4
      - async-timeout==4.0.3
      - attrs==23.2.0
      - boto3==1.26.121
      - botocore==1.29.121
      - cachetools==5.3.0
      - chardet==5.2.0
      - click==8.1.3
      - cmake==3.26.3
      - colorama==0.4.6
      - contourpy==1.2.0
      - cycler==0.12.1
      - dataproperty==1.0.1
      - datasets==2.17.1
      - deepspeed==0.12.0
      - dill==0.3.7
      - docker-pycreds==0.4.0
      - einops==0.7.0
      - evaluate==0.4.1
      - fairscale==0.4.13
      - fastapi==0.108.0
      - ffmpy==0.3.1
      - filelock==3.12.0
      - flash-attn==2.4.2
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fsspec==2023.10.0
      - fuzzywuzzy==0.18.0
      - gauge==0.1.2
      - gitdb==4.0.10
      - gitpython==3.1.31
      - google-auth==2.17.3
      - google-auth-oauthlib==1.0.0
      - gradio==4.14.0
      - gradio-client==0.8.0
      - greenlet==3.0.3
      - grpcio==1.54.0
      - h11==0.14.0
      - hjson==3.1.0
      - httpcore==1.0.2
      - httptools==0.6.1
      - httpx==0.26.0
      - huggingface-hub==0.20.3
      - hydra-core==1.3.2
      - importlib-resources==6.1.1
      - jieba==0.42.1
      - jinja2==3.1.2
      - jmespath==1.0.1
      - joblib==1.2.0
      - jsonline==0.2.1
      - jsonlines==4.0.0
      - jsonschema==4.20.0
      - jsonschema-specifications==2023.12.1
      - kiwisolver==1.4.5
      - lit==16.0.2
      - lm-eval==1.0.0
      - markdown==3.4.3
      - markdown-it-py==3.0.0
      - markupsafe==2.1.2
      - matplotlib==3.8.2
      - mbstrdecoder==1.1.3
      - mdurl==0.1.2
      - megablocks==0.5.0
      - megatron-core==0.4.0
      - mpmath==1.3.0
      - msgpack==1.0.7
      - multidict==6.0.4
      - multiprocess==0.70.15
      - networkx==3.1
      - ninja==1.11.1
      - nltk==3.8.1
      - numexpr==2.8.8
      - numpy==1.24.3
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.18.1
      - nvidia-nvjitlink-cu12==12.3.101
      - nvidia-nvtx-cu12==12.1.105
      - oauthlib==3.2.2
      - omegaconf==2.3.0
      - openai==0.28.1
      - orjson==3.9.10
      - packaging==23.1
      - pandas==2.1.4
      - pathtools==0.1.2
      - pathvalidate==3.2.0
      - peft==0.7.1
      - pillow==9.5.0
      - pip==23.3.2
      - portalocker==2.8.2
      - protobuf==3.20.0
      - psutil==5.9.5
      - py-cpuinfo==9.0.0
      - pyarrow==14.0.2
      - pyarrow-hotfix==0.6
      - pyasn1==0.5.0
      - pyasn1-modules==0.3.0
      - pybind11==2.10.4
      - pycocoevalcap==1.2
      - pycocotools==2.0.7
      - pycountry==23.12.11
      - pydantic==1.10.13
      - pydantic-core==2.14.6
      - pydub==0.25.1
      - pyflakes==3.2.0
      - pynvml==11.5.0
      - pyparsing==3.1.1
      - pyprof==0.0.7
      - pytablewriter==1.2.0
      - python-dotenv==1.0.0
      - python-multipart==0.0.6
      - pytz==2023.3.post1
      - pyyaml==6.0
      - quantile-python==1.1
      - rapidfuzz==3.6.1
      - ray==2.9.0
      - referencing==0.32.1
      - regex==2023.3.23
      - requests-oauthlib==1.3.1
      - responses==0.18.0
      - rich==13.7.0
      - rouge-score==0.1.2
      - rpds-py==0.16.2
      - rsa==4.9
      - s3transfer==0.6.0
      - sacrebleu==1.5.0
      - safetensors==0.4.1
      - scikit-learn==1.3.2
      - scipy==1.11.4
      - semantic-version==2.10.0
      - sentencepiece==0.1.99
      - sentry-sdk==1.22.2
      - setproctitle==1.3.2
      - shellingham==1.5.4
      - smmap==5.0.0
      - sniffio==1.3.0
      - sqlalchemy==2.0.25
      - sqlitedict==2.1.0
      - stanford-stk==0.0.6
      - starlette==0.32.0.post1
      - sympy==1.11.1
      - tabledata==1.3.3
      - tcolorpy==0.1.4
      - tensorboard==2.13.0
      - tensorboard-data-server==0.7.0
      - thefuzz==0.20.0
      - threadpoolctl==3.2.0
      - tiktoken==0.5.2
      - tokenizers==0.15.2
      - tomlkit==0.12.0
      - torch==2.1.2+cu118
      - torchaudio==2.0.1+cu118
      - torchvision==0.15.1+cu118
      - tqdm-multiprocess==0.0.11
      - transformers==4.36.0
      - triton==2.1.0
      - typepy==1.3.2
      - typer==0.9.0
      - tzdata==2023.4
      - uvicorn==0.25.0
      - uvloop==0.19.0
      - vllm==0.2.4+cu118
      - wandb==0.16.2
      - watchfiles==0.21.0
      - websockets==11.0.3
      - werkzeug==2.3.3
      - xformers==0.0.23.post1+cu118
      - xxhash==3.4.1
      - yarl==1.9.4
      - zhconv==1.4.3
prefix: /root/miniconda3

Could you try this environment setup? It's the exact environment I used for my experiments.
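To check how closely a local environment matches the file above, one option is to compare a few of its pip pins against what is actually installed. A minimal sketch using the standard-library `importlib.metadata`; the `PINS` subset and the `compare_pins` helper are illustrative choices, not something the repository provides:

```python
from importlib import metadata

# A few pip pins copied from the environment file above (illustrative subset).
PINS = {
    "torch": "2.1.2+cu118",
    "transformers": "4.36.0",
    "accelerate": "0.26.1",
    "deepspeed": "0.12.0",
    "peft": "0.7.1",
}


def compare_pins(pins: dict) -> dict:
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in compare_pins(PINS).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every listed pin matches exactly, including the local `+cu118` suffix on the torch version.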