zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License
1.77k stars 213 forks

Using MEND with monkeypatch from higher #342

Closed xduan7 closed 1 month ago

xduan7 commented 1 month ago

Hi. Thank you for the awesome package. I found something strange when trying to use MEND with Llama 2 7B. Specifically, the monkey-patched Llama 2 model produces different results than the original (un-patched) model. Here is a script that demonstrates the problem:

import torch
import loguru
from transformers import AutoModelForCausalLM, AutoTokenizer

from easyeditor import monkeypatch  # Change the path if necessary

logger = loguru.logger

model_name = "meta-llama/Llama-2-7b-hf"
# model_name = "openai-community/gpt2-xl"
modified_param_names = [
    "model.layers.29.mlp.gate_proj.weight",
    "model.layers.29.mlp.up_proj.weight",
    "model.layers.29.mlp.down_proj.weight",
] if "llama" in model_name.lower() else [
    "transformer.h.45.mlp.c_proj.weight",
    "transformer.h.45.mlp.c_fc.weight",
]
device = "auto" if "llama" in model_name.lower() else "cuda:0"

llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
)
tkn = AutoTokenizer.from_pretrained(
    model_name,
)
llm.eval()

# Get the logits of a random prompt
prompt = "Who is the president of the United States in 2022?"
tokens = tkn(prompt, return_tensors="pt").to("cuda")
logits = llm.forward(**tokens).logits

# Construct a monkey-patched version of the model
llm = monkeypatch(llm, in_place=True)
logits_after_patch = llm.forward(**tokens).logits

# Since no weights are modified, the logits should be the same
if torch.all(logits == logits_after_patch):
    logger.info("The logits did not change after monkey-patching.")
else:
    logger.error("The logits changed after monkey-patching.")

# Modify the monkey-patched model's weight
# Similar modification from `easyeditor/trainer/alg/MEND` line 399
new_params = []
weight_addition = 1.0
with torch.no_grad():
    for n, p in llm.named_parameters():
        if n in modified_param_names:
            print(f"Adding {weight_addition} to {n} ...")
            new_params.append(p + weight_addition)
        else:
            new_params.append(p)
llm.update_params(new_params)

# Confirm the desired weight change using a fresh model
ref_llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
)
for n, p in llm.named_parameters():
    if n in modified_param_names:
        assert torch.all(p == ref_llm.state_dict()[n] + weight_addition)
    else:
        assert torch.all(p == ref_llm.state_dict()[n])

# See if the logits change after weight modification
logits_after_modification = llm.forward(**tokens).logits
if torch.all(logits == logits_after_modification):
    logger.error("The logits did not change after modifying the model.")
else:
    logger.info("The logits changed after modifying the model.")

# Now load the modified model to a fresh model
ref_llm.load_state_dict(llm.state_dict())
llm = ref_llm
ref_llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
)
logits_after_loading = llm.forward(**tokens).logits
if torch.all(logits_after_modification == logits_after_loading):
    logger.info(
        "The logits before and after loading the `state_dict` from the"
        " monkey-patched model with weight modification are the same."
    )
else:
    logger.error(
        "Loading the `state_dict` from the monkey-patched model with"
        " weight modification did not result in the same logits as the"
        " said monkey-patched model."
    )
if torch.all(logits == logits_after_loading):
    logger.error(
        "The model returned the same logits after loading the `state_dict`"
        " from the monkey-patched model with weight modification compared to"
        " the original model."
    )
else:
    logger.info(
        "The model returned different logits after loading the `state_dict`"
        " from the monkey-patched model with weight modification compared to"
        " the original model."
    )

# Make sure that the weight modification persists after loading the model
for n, p in llm.named_parameters():
    if n in modified_param_names:
        assert torch.all(p == ref_llm.state_dict()[n] + weight_addition)
    else:
        assert torch.all(p == ref_llm.state_dict()[n])

# Lastly, we check if the logits produced by the original forward method
# can produce the right logits that correspond to the weight modification
logits_with_original_forward = ref_llm.__class__.forward(
    llm,
    **tokens
).logits
if logits_with_original_forward.device != logits_after_loading.device:
    logits_with_original_forward = \
        logits_with_original_forward.to(logits_after_loading.device)
if not torch.all(logits_after_loading == logits_with_original_forward):
    logger.error(
        "The logits after loading the model and"
        " using the original forward method are the same."
    )

# Print all the logits
print("Original logits:\n", logits)
print("Logits after monkey-patching:\n", logits_after_patch)
print(
    "Logits after modification in monkey-patched model:\n",
    logits_after_modification
)
print(
    "Logits after loading the modified monkey-patched model:\n",
    logits_after_loading
)
print(
    "Logits after using the original forward with the loaded model:\n",
    logits_with_original_forward
)

Basically, this script demonstrates that the Llama 2 model's forward changes after the monkey patch. Once a Llama 2 model is monkey-patched, weight changes to the model no longer lead to any change in the output (logits). However, by either (1) casting the modified model back to its original class (not monkey-patched), or (2) using the original class's forward method (in this case, LlamaForCausalLM.forward) with the modified model, I was able to generate the correct output (logits) corresponding to the weight changes. This implies that the forward method of monkey-patched Llama models is probably not working as intended.

However, it seems that the monkey patch works with GPT-2 XL. You can simply swap model_name in the script to "openai-community/gpt2-xl" to verify that the problem is specific to Llama.

I'm wondering whether it is absolutely necessary to stick to monkey-patched models in MEND. If so, is there any workaround? And just FYI, I'm using `transformers` version 4.42.4 and `higher` version 0.2.1.

Thank you.

pengzju commented 1 month ago

Thank you for providing such a detailed script to reproduce certain issues with monkeypatch. In fact, the use of monkeypatch can be traced back to the official implementation of MEND: https://github.com/eric-mitchell/mend/blob/main/algs/mend.py#L230.
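For reference, the pattern there is roughly the following sketch, written with only the `monkeypatch`/`update_params` calls from your script; `deltas` is a hypothetical stand-in for the per-weight updates that MEND's hypernetwork actually predicts:

import torch
from easyeditor import monkeypatch  # adjust the import path if needed

def apply_mend_style_edit(model, deltas):
    """Wrap the model functionally and shift selected weights by `deltas`."""
    edited = monkeypatch(model, in_place=True)
    with torch.no_grad():
        new_params = [
            p + deltas[n] if n in deltas else p
            for n, p in edited.named_parameters()
        ]
    edited.update_params(new_params)
    return edited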

However, in past experiments, we have observed successful editing of Llama-7B with MEND (Edit success = 1.0). Therefore, monkeypatch is unlikely to cause logits to stop changing. Please rest assured that I will attempt to reproduce this asap.

zxlzr commented 1 month ago

Hi, do you have any further issues?

xduan7 commented 1 month ago

@zxlzr Sorry for the late reply, but my question has not been answered yet.

I understand @pengzju's statement that MEND worked fine in your evaluations. So I'm wondering whether you could reproduce the results in my script. If so, my guess would be that it is an issue of package versioning. Otherwise, I'm wondering whether monkeypatch is necessary for MEND at all. From what I know, make_functional does nothing other than wrap the model into a stateless version exposed through its forward function. Since the other knowledge editing methods do not use it, I'm thinking it might not be necessary for MEND either.
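For what it's worth, here is a conceptual sketch of what I mean by a stateless wrapper, written with torch.func.functional_call (so it is only an illustration of the idea, not higher's or EasyEdit's actual monkeypatch code); the parameter name in the usage comment is just the one from my script:

import torch
from torch.func import functional_call  # PyTorch >= 2.0

def stateless_forward(model, edited_params, **inputs):
    """Run the model's original forward with some parameters overridden.

    `edited_params` maps parameter names to replacement tensors; parameters
    not listed fall back to the module's own weights, and nothing is patched
    or mutated.
    """
    return functional_call(model, edited_params, args=(), kwargs=inputs)

# Usage sketch: perturb one weight without touching the model itself.
# name = "model.layers.29.mlp.down_proj.weight"
# w = dict(llm.named_parameters())[name]
# logits = stateless_forward(llm, {name: w + 1.0}, **tokens).logits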

zxlzr commented 1 month ago

Sorry, we are busy with a rebuttal at the moment; we will handle this issue in the next few days.

xduan7 commented 1 month ago

No worries. Take your time. Hope your rebuttal goes well.

pengzju commented 1 month ago

@xduan7

Thank you so much for providing detailed code to tackle this tricky bug. As I mentioned before, many past experiments have shown that MEND can effectively modify weights.

I did manage to reproduce your issue, but the problem isn't that "it doesn't work on LLaMA"; rather, the monkeypatch loses its effectiveness when the model is run in parallel. In the code you provided, simply changing device to 'cpu' or a single GPU instead of 'auto' allows all tests to pass.

I suspect the underlying issue is an incompatibility between the Hugging Face Transformers model-parallel logic (`device_map="auto"`) and the monkeypatch. This means that as long as you ensure the model runs on a single device, you can avoid the issue. That said, we probably can't completely eliminate this issue, but I've provided a way to avoid the bug.
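Concretely, in your script this just means loading the model onto one device, for example (float16 is only there so the 7B model fits on a single GPU; adjust as needed):

import torch
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="cuda:0",        # a single device instead of "auto"
    torch_dtype=torch.float16,  # assumes the 7B model fits on one GPU in fp16
)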

I really enjoy our discussion and hope this resolves your problem.

zxlzr commented 1 month ago

hi, do you have any further questions?

xduan7 commented 1 month ago

@pengzju @zxlzr Thank you for the reply. This is very helpful, and I can confirm that using a single GPU solves the problem. Before we close the issue, I would like to extend the discussion a little. The purpose of the monkey patch seems to be to make the model stateless and wrap it against further modification, so removing it should not affect the knowledge editing results. Do you think it's a good idea to simply remove the monkey patch from MEND?

pengzju commented 1 month ago

I think this is a good idea. In my previous experience, external patches and model parallelism have often conflicted. I would also encourage you to try directly modifying the model's weights. I'd welcome your feedback or a PR if you can get it to work on multiple GPUs.
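For the record, a minimal sketch of that "directly modify the weights" route, with no monkeypatch or update_params (`deltas` is again a hypothetical mapping from parameter name to additive update):

import torch

def apply_edit_in_place(model, deltas):
    """Write additive updates directly into the model's parameters."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in deltas:
                # Move each delta onto the parameter's own device/dtype so this
                # also works when the model is sharded across several GPUs.
                param.add_(deltas[name].to(device=param.device, dtype=param.dtype))
    return model

Keeping the edit differentiable during MEND's meta-training is presumably why the functional wrapper exists in the first place, so this in-place version only covers applying an already-computed edit at inference time.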

Thanks again!

xduan7 commented 1 month ago

Thank you for everything.