Traceback (most recent call last):
File "/data/renky/EasyEdit/test.py", line 21, in <module>
metrics, edited_model_false, _ = editor.edit(
File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 183, in edit
return self.edit_requests(requests, sequential_edit, verbose, test_generation=test_generation, **kwargs)
File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 371, in edit_requests
edited_model, weights_copy, icl_examples = edit_func(request)
File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 319, in edit_func
edited_model, weights_copy = self.apply_algo(
File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 46, in apply_memit_to_model
deltas = execute_memit(model, tok, requests, hparams, cache_template=cache_template)
File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 137, in execute_memit
cur_z = compute_z(
File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 129, in compute_z
logits = model(**input_tok).logits
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
outputs = self.model(
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
layer_outputs = decoder_layer(
File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
hook_result = hook(self, args, result)
File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 80, in retain_hook
output = invoke_with_optional_args(
File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 454, in invoke_with_optional_args
return fn(*pass_args, **pass_kw)
File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 106, in edit_output_fn
cur_out[0][i, idx, :] += delta
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
(11.14 更新:修改了错别字和表述问题。)
由于 7B 和 3B 模型在双卡 3090 上运行时,都会出现第一张卡爆显存(OOM)的情况,因此改在 4 卡 3090 上运行基于 LLaMA-7B 的 MEMIT,结果出现如题报错;查看类似 issue 后仍未能解决问题。代码如下:
超参如下:
报错如下(完整 Traceback 见上文):
烦请告知是哪里出了问题,或者我需要做哪些更改,谢谢!