zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License

Multi-GPU parallel run fails: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! #419

Open Darknessrky opened 3 days ago

Darknessrky commented 3 days ago

Edited 11.14: fixed typos and wording.

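Both the 7B and the 3B models run out of memory on the first GPU when I use two 3090s, so I ran MEMIT with llama-7b on four 3090s instead and got the error in the title. I looked through similar issues but could not resolve it. The code is as follows: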

import sys
import os
import json
from easyeditor.editors.editor import BaseEditor
from easyeditor import MEMITHyperParams

sys.path.append('/data/renky/EasyEdit')
os.chdir("/data/renky/EasyEdit")
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"

# load zsre_data
edit_data = json.load(open('./data/ZsRE/ZsRE-test-all.json', 'r', encoding='utf-8'))[:100]
prompts = [edit_data_['prompt'] for edit_data_ in edit_data]
ground_truth = [edit_data_['ground_truth'][0] for edit_data_ in edit_data]  
subject = [edit_data_['subject'] for edit_data_ in edit_data]
target_new = [edit_data_['target_new'] for edit_data_ in edit_data]

# MEMIT
hparams=MEMITHyperParams.from_hparams('./hparams/MEMIT/llama-7b.yaml')
editor = BaseEditor.from_hparams(hparams)
metrics, edited_model_false, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject,
    keep_original_weight=False
)
print(metrics)
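One detail in the script above may be worth checking: `CUDA_VISIBLE_DEVICES` is set to "1,2,3,4". Assuming the machine has exactly four GPUs (physical indices 0 to 3, which the thread does not confirm), index 4 is invalid, and CUDA's documented rule is that only the devices listed before an invalid index stay visible. The sketch below simulates that rule in plain Python; `visible_devices` is a hypothetical helper written for illustration, not a CUDA or EasyEdit API, and it handles integer indices only (not GPU UUIDs).

```python
from typing import List


def visible_devices(cuda_visible_devices: str, physical_gpu_count: int) -> List[int]:
    """Simulate how CUDA resolves a CUDA_VISIBLE_DEVICES list of integer indices.

    Hypothetical helper for illustration only.
    """
    visible = []
    for token in cuda_visible_devices.split(","):
        token = token.strip()
        # An invalid index hides itself and every entry after it.
        if not token.isdigit() or int(token) >= physical_gpu_count:
            break
        visible.append(int(token))
    return visible


# Assumption: a machine with exactly four GPUs, indexed 0..3.
print(visible_devices("1,2,3,4", physical_gpu_count=4))
print(visible_devices("0,1,2,3", physical_gpu_count=4))
```

Under that assumption the first call leaves three GPUs visible rather than four, so "0,1,2,3" would be the setting that exposes all four cards. The variable is also conventionally set before any library that initializes CUDA is imported, whereas the script sets it after the `easyeditor` imports.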

The hyperparameters are as follows:

alg_name: "MEMIT"
model_name: "./hugging_cache/llama-2-7b"
stats_dir: "./data/stats"
device: 0
layers: [4, 5, 6, 7, 8]
clamp_norm_factor: 4
layer_selection: "all"
fact_token: "subject_last"
v_num_grad_steps: 25
v_lr: 5e-1
v_loss_layer: 31
v_weight_decay: 1e-3
kl_factor: 0.0625
mom2_adjustment: true
mom2_update_weight: 15000
rewrite_module_tmp: "model.layers.{}.mlp.down_proj"
layer_module_tmp: "model.layers.{}"
mlp_module_tmp: "model.layers.{}.mlp"
attn_module_tmp: "model.layers.{}.self_attn"
ln_f_module: "model.norm"
lm_head_module: "lm_head"
mom2_dataset: "wikipedia"
mom2_n_samples: 100000
mom2_dtype: "float32"
model_parallel: true

The error is as follows:

Traceback (most recent call last):
  File "/data/renky/EasyEdit/test.py", line 21, in <module>
    metrics, edited_model_false, _ = editor.edit(
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 183, in edit
    return self.edit_requests(requests, sequential_edit, verbose, test_generation=test_generation, **kwargs)
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 371, in edit_requests
    edited_model, weights_copy, icl_examples = edit_func(request)
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 319, in edit_func
    edited_model, weights_copy = self.apply_algo(
  File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 46, in apply_memit_to_model
    deltas = execute_memit(model, tok, requests, hparams, cache_template=cache_template)
  File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 137, in execute_memit
    cur_z = compute_z(
  File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 129, in compute_z
    logits = model(**input_tok).logits
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
    hook_result = hook(self, args, result)
  File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 80, in retain_hook
    output = invoke_with_optional_args(
  File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 454, in invoke_with_optional_args
    return fn(*pass_args, **pass_kw)
  File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 106, in edit_output_fn
    cur_out[0][i, idx, :] += delta
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Could you please tell me what went wrong, or what I need to change? Thanks!
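The traceback ends at the in-place add `cur_out[0][i, idx, :] += delta` inside a forward hook. With `model_parallel: true` the decoder layers are spread across GPUs, so the hooked layer's output can sit on a different device than `delta`. One possible workaround (not confirmed by the maintainers in this thread) is to move `delta` to the output's device before adding. The sketch below shows the pattern on a toy module; `layer`, `delta`, and this `edit_output_fn` are stand-ins, not EasyEdit code, and it runs on CPU, where the `.to()` call is a no-op.

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)   # stand-in for one decoder layer
delta = torch.ones(4)     # stand-in for the vector being optimized


def edit_output_fn(module, inputs, output):
    # Align devices first: in a sharded model, output.device may differ
    # from delta.device, which is what triggers the RuntimeError.
    output[:, -1, :] += delta.to(output.device)
    return output


x = torch.zeros(1, 3, 4)  # (batch, sequence, hidden)

handle = layer.register_forward_hook(edit_output_fn)
with torch.no_grad():
    edited = layer(x)     # hook adds delta at the last position
handle.remove()

with torch.no_grad():
    plain = layer(x)      # same layer without the hook

# Only the last sequence position should differ, by exactly delta.
print(torch.allclose(edited[:, -1, :] - plain[:, -1, :], delta))
```

Applying the same `.to(...)` alignment at line 106 of `compute_z.py` would be the analogous change, though other tensors in `compute_z` may need the same treatment.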

XeeKee commented 1 day ago

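It runs fine in my local test, although it also runs out of GPU memory. You could try enabling quantization for the model, or switch to a machine with more GPU memory.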

Darknessrky commented 1 day ago

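With 2 or fewer GPUs it runs but goes out of memory; with 3 or more GPUs it raises the error above, haha. I'll keep trying and see whether I can resolve it.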

zxlzr commented 4 hours ago

Hi, do you have any further issues?