zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License

Question about multimodal edit base model #217

Closed luludus closed 7 months ago

luludus commented 7 months ago

Why is the base model's Reliability zero? I use code like this:

eval_loader = DataLoader(eval_ds, batch_size=1, shuffle=True, collate_fn=eval_ds.collate_fn)

with torch.no_grad():
    for i, batch in enumerate(eval_loader):
        print(batch)

        inner_edit_outputs = model(batch["edit_inner"])
        inner_batch_labels = batch["edit_inner"]["labels"]
        print(inner_batch_labels)

        if not isinstance(inner_edit_outputs, torch.Tensor):
            inner_edit_logits = inner_edit_outputs.logits
        else:
            inner_edit_logits = inner_edit_outputs

        if inner_edit_logits.shape[1] > inner_batch_labels.shape[1]:
            print('1')
            inner_edit_dict = masked_log_probs(hparams, inner_edit_logits, inner_batch_labels)
        else:
            inner_edit_dict = masked_log_probs(hparams, inner_edit_logits, inner_batch_labels[:, -inner_edit_logits.shape[1]-1:])

        print(inner_edit_dict)
        print(inner_edit_dict['acc'])

The masked_log_probs is in trainer/losses.py, but the results look like:

tensor(0.6154, device='cuda:1') tensor(0.4615, device='cuda:1') tensor(0.0909, device='cuda:1') tensor(0.2727, device='cuda:1') tensor(0.5000, device='cuda:1') tensor(0.2308, device='cuda:1') tensor(0.4615, device='cuda:1') tensor(0.3000, device='cuda:1') tensor(0.4545, device='cuda:1') tensor(0.3333, device='cuda:1') tensor(0.3000, device='cuda:1') tensor(0.2222, device='cuda:1') tensor(0.3077, device='cuda:1') tensor(0.3846, device='cuda:1') tensor(0.3000, device='cuda:1') tensor(0.3636, device='cuda:

tbozhong commented 7 months ago

When constructing our dataset, we filtered out instances where the model's predictions did not exactly match the labels, which differs from our evaluation criteria (Accuracy) for editing methods.

Our evaluation criterion for all methods in multimodal editing indeed requires an exact match between the model's predictions and the labels, which is a more stringent standard than what is used by EasyEdit.
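
To make the difference concrete, here is a toy illustration (not EasyEdit code) of the two criteria on a single three-token answer:

import torch

targ = torch.tensor([[5, 7, 9]])        # gold token ids for one answer
pred_ids = torch.tensor([[5, 7, 2]])    # the model gets 2 of 3 tokens right

token_acc = (pred_ids == targ).float().mean()              # 0.6667, EasyEdit's default token-level accuracy
exact_match = (pred_ids == targ).all(-1).float().mean()    # 0.0, exact match over the whole answer
print(token_acc.item(), exact_match.item())

With batch_size=1, as in your snippet above, each example then contributes either 0.0 or 1.0 rather than a fractional token accuracy.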

luludus commented 7 months ago

When constructing our dataset, we filtered out instances where the model's predictions did not exactly match the labels, which differs from our evaluation criteria (Accuracy) for editing methods.

Thanks for your reply!

I have this question because there is a small difference between my results and yours when reproducing MEND. Also, my performance with the KE method is significantly higher than your reported results on generalizability and reproducibility. I don't know what causes this.

Besides, what method do you use to evaluate the reproducibility of the base model?

tbozhong commented 7 months ago

You can refer to here for evaluating the base model. As for KE, please give me more details (such as the backbone model, training log, and which metrics; reproducibility is not used in our experiments) so I can reproduce your issue.

luludus commented 7 months ago

You can refer to here for evaluating the base model. As for KE, please give me more details (such as the backbone model, training log, and which metrics; reproducibility is not used in our experiments) so I can reproduce your issue.

Hi, my method of evaluating the base model is consistent with the link you gave. The result in my first question is on the E-IC dataset. Why is this happening?

I'm sorry, I referred to reliability as reproducibility. This is the result of my KE on E-IC, and I edited KE.edit as follows:

def edit(self, batch, condition, detach_history=False, **kwargs):
    if 'minigpt4' in self.config.model_name.lower() or 'blip' in self.config.model_name.lower():
        # print("!!!!!!")

        # print(batch['image'].device)
        outputs = self.model(batch) 

        if not isinstance(outputs, torch.Tensor):
            # batch_labels = outputs.labels
            outputs = outputs.logits
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]          
    elif 'gpt' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        if not kwargs:
            loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]
        else:
            loss = self.edit_loss_fn(self.config, outputs, batch["labels"], **kwargs)["nll"]
    elif 'llama' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        if not kwargs:
            loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]
        else:
            loss = self.edit_loss_fn(self.config, outputs, batch["labels"], **kwargs)["nll"]
    elif 'baichuan' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"] 
    elif 'chatglm2' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]            
    elif 'internlm' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]  
    elif 'qwen' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]         
    elif 'mistral' in self.config.model_name.lower():
        outputs = _logits(self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']))
        # outputs = outputs[:, -batch['labels'].shape[-1]:, :]
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]  
    else:
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(self.config, outputs, batch["labels"])["nll"]

    names = set([n for n, p in self.model.named_parameters()])
    pset = set(self.config.inner_params)
    for p in pset:
        assert p in names, f"inner param {p} not in model"

    grads = torch.autograd.grad(
        loss,
        [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.inner_params)]
    )

    params_dict = self.editor(
        condition["input_ids"] if condition is not None else batch["input_ids"],
        condition["attention_mask"] if condition is not None else batch["attention_mask"],
        {n: g.to(torch.float32) for (n, g) in zip(self.config.inner_params, grads)},
    )

    edited_model = self.model
    if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
        edited_model = make_functional(edited_model, in_place=True)

    def new_param(n, p):
        if n not in params_dict:
            return p

        if p.shape[0] == params_dict[n].shape[0]:
            return p + params_dict[n]
        else:
            return p + params_dict[n].T

    edited_model.update_params(
        [new_param(n, p) for (n, p) in edited_model.named_parameters()]
    )

    if detach_history:
        new_model = self.model_constructor()
        new_model.load_state_dict(edited_model.state_dict())
        edited_model = new_model

    return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}


03/25/2024 16:07:34 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_train: 2.35340; loss/image_edit_train: 2.91150; loss/loc_train: 0.00241; edit/acc_train: 0.49732; edit/log_prob_train: -2.35340; edit/prob_train: 0.30464; inner/acc_train: 0.49095; image_rephrase/acc_train: 0.41847; time/edit_train: 0.32044; loc/acc_train: 0.98176; image_loc/acc_train: 0.65422; loss/total_train: 0.54402; loss/total_edit_train: 0.54402; memory/alloc_max_train: 19263224832.00000; memory/res_max_train: 21787312128.00000; grad_train: 0.10190
03/25/2024 16:09:28 - INFO - easyeditor.trainer.BaseTrainer - Step 4700:
03/25/2024 16:09:28 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_train: 2.19381; loss/image_edit_train: 2.65857; loss/loc_train: 0.00195; edit/acc_train: 0.51908; edit/log_prob_train: -2.19381; edit/prob_train: 0.32686; inner/acc_train: 0.51150; image_rephrase/acc_train: 0.47261; time/edit_train: 0.32114; loc/acc_train: 0.98419; image_loc/acc_train: 0.65937; loss/total_train: 0.50039; loss/total_edit_train: 0.50039; memory/alloc_max_train: 19263569510.40000; memory/res_max_train: 21787312128.00000; grad_train: 0.07531
03/25/2024 16:11:21 - INFO - easyeditor.trainer.BaseTrainer - Step 4800:
03/25/2024 16:11:21 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_train: 2.38431; loss/image_edit_train: 2.84697; loss/loc_train: 0.00230; edit/acc_train: 0.51983; edit/log_prob_train: -2.38431; edit/prob_train: 0.31777; inner/acc_train: 0.51993; image_rephrase/acc_train: 0.44956; time/edit_train: 0.31707; loc/acc_train: 0.98219; image_loc/acc_train: 0.65879; loss/total_train: 0.53873; loss/total_edit_train: 0.53873; memory/alloc_max_train: 19266358272.00000; memory/res_max_train: 21787312128.00000; grad_train: 0.10020
03/25/2024 16:13:15 - INFO - easyeditor.trainer.BaseTrainer - Step 4900:
03/25/2024 16:13:15 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_train: 2.29717; loss/image_edit_train: 2.74778; loss/loc_train: 0.00237; edit/acc_train: 0.52210; edit/log_prob_train: -2.29717; edit/prob_train: 0.32702; inner/acc_train: 0.53358; image_rephrase/acc_train: 0.45422; time/edit_train: 0.31937; loc/acc_train: 0.98191; image_loc/acc_train: 0.66626; loss/total_train: 0.51935; loss/total_edit_train: 0.51935; memory/alloc_max_train: 19266358272.00000; memory/res_max_train: 21787312128.00000; grad_train: 0.07744
03/25/2024 16:15:09 - INFO - easyeditor.trainer.BaseTrainer - Step 5000:
03/25/2024 16:15:09 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_train: 2.30292; loss/image_edit_train: 2.70138; loss/loc_train: 0.00215; edit/acc_train: 0.51670; edit/log_prob_train: -2.30292; edit/prob_train: 0.33341; inner/acc_train: 0.52501; image_rephrase/acc_train: 0.46685; time/edit_train: 0.32161; loc/acc_train: 0.95933; image_loc/acc_train: 0.65356; loss/total_train: 0.51603; loss/total_edit_train: 0.51603; memory/alloc_max_train: 19266358272.00000; memory/res_max_train: 21787312128.00000; grad_train: 0.08240
03/25/2024 16:23:45 - INFO - easyeditor.trainer.BaseTrainer - Step 5000:
03/25/2024 16:23:45 - INFO - easyeditor.trainer.BaseTrainer - loss/edit_val: 2.19956; loss/image_edit_val: 2.66498; loss/loc_val: 0.00235; edit/acc_val: 0.54066; edit/log_prob_val: -2.19956; edit/prob_val: 0.34405; inner/acc_val: 0.54251; image_rephrase/acc_val: 0.48539; time/edit_val: 0.32371; loc/acc_val: 0.94785; image_loc/acc_val: 0.65332; loss/total_val: 0.50227; loss/total_edit_val: 0.50227; memory/alloc_max_val: 19856511186.94400; memory/res_max_val: 23193242828.80000; eval_time/elapsed: 515.98927; eval_time/average: 1.03198
03/25/2024 16:23:45 - INFO - easyeditor.trainer.BaseTrainer - Saving model to ./results/models/KE/blip2

tbozhong commented 7 months ago

Thank you for bringing this to our attention. Upon reviewing our original code, I realize there has been a misunderstanding. For clarity, our evaluation criterion for all methods indeed requires an exact match between the model's predictions and the labels, which is a more stringent standard than what is used by EasyEdit. To align with our evaluation protocol and accurately reproduce our results, please update the code in trainer/losses.py at line 55 as follows:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    correct = correct & mask
    num_non_padding = mask.sum().float().item()

    if 't5' in config.model_class.lower():
        end_mask = targ != 1
        correct = correct & end_mask
        num_non_padding = (mask & end_mask).sum().float().item()
    acc = correct.sum() / num_non_padding

    # ... [other code] ...

Replace it with the following updated evaluation logic:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We aim for an exact match across the entire sequence
    acc = correct.float().mean()

    # ... [other code] ...

This modification ensures that our accuracy metric reflects the requirement for predictions to be entirely correct across the full sequence, not just at individual token positions. We apologize for any confusion and appreciate your cooperation in maintaining the integrity of our evaluation standards.

luludus commented 7 months ago

Thank you for your sincere reply! Can you give me some advice about the KE code? I use the Blip2OPT model and modified KE.edit as shown in the previous comment, and KE.__init__ as follows. Is my code wrong?

class KE(EditableModel):
    def __init__(self, model, config, model_constructor, editor=None):
        # super().__init__(model, config, model_constructor)

        if editor is None:
            if isinstance(model, BertClassifier):
                embedding = model.model.embeddings.word_embeddings.weight.data
            elif isinstance(model, BartForConditionalGeneration):
                embedding = model.model.shared.weight.data
            elif isinstance(model, T5ForConditionalGeneration):
                embedding = model.shared.weight.data
            else:
                embedding = model.opt_model.model.decoder.embed_tokens.weight.data  # wte.weight.data

            editor = OneShotLearner(model, vocab_dim=model.opt_model.config.vocab_size,
                                    include_set=config.inner_params,
                                    embedding_dim=embedding.shape[-1],
                                    embedding_init=embedding.clone().to(torch.float32),
                                    max_scale=1)
        self.editor = editor

tbozhong commented 7 months ago

Your KE initialization looks correct. Could you give me your code for evaluation?

luludus commented 7 months ago

ke.txt

This is my KE.py, placed at easyeditor/trainer/algs/KE.py.

tbozhong commented 7 months ago

To align with our evaluation protocol and accurately reproduce our results, please update the code in trainer/losses.py at line 55 as follows:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    correct = correct & mask
    num_non_padding = mask.sum().float().item()

    if 't5' in config.model_class.lower():
        end_mask = targ != 1
        correct = correct & end_mask
        num_non_padding = (mask & end_mask).sum().float().item()
    acc = correct.sum() / num_non_padding

    # ... [other code] ...

Replace it with the following updated evaluation logic:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We aim for an exact match across the entire sequence
    acc = correct.float().mean()

    # ... [other code] ...

Please try this to reproduce. The way of evaluation will influence the results.

asdfo123 commented 7 months ago

To align with our evaluation protocol and accurately reproduce our results, please update the code in trainer/losses.py at line 55 as follows:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    correct = correct & mask
    num_non_padding = mask.sum().float().item()

    if 't5' in config.model_class.lower():
        end_mask = targ != 1
        correct = correct & end_mask
        num_non_padding = (mask & end_mask).sum().float().item()
    acc = correct.sum() / num_non_padding

    # ... [other code] ...

Replace it with the following updated evaluation logic:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We aim for an exact match across the entire sequence
    acc = correct.float().mean()

    # ... [other code] ...

Please try this to reproduce our results. The way of evaluation will influence the results.

Hi, I noticed that the evaluation used in compute_multimodal_edit_quality() in evaluate/evaluate.py at line 442 is the same as the former (token-level) one, as follows:

def compute_multimodal_edit_quality(model, batch):

    with torch.no_grad():
        outputs = model(batch)
        if isinstance(outputs, torch.Tensor):
            logits = outputs.detach().cpu()
        else:
            logits = outputs.logits.detach().cpu()    
        # targ = outputs.labels.detach().cpu()
        targ = batch["labels"].cpu()
    if logits.dim() == 3:
        logits = logits[:, :-1]
        # targ = targ[:, 1:]
        logits = logits[:, -targ.shape[1]:]
    mask = targ != -100
    targ[~mask] = 0
    pred_ids = logits.argmax(-1).masked_fill(~mask, 0).detach().cpu()
    correct = pred_ids == targ
    correct = correct & mask
    num_non_padding = mask.sum().float().item()
    acc = correct.sum() / num_non_padding

    return acc, pred_ids.numpy() 

Should I also change this to align with your evaluation protocol?
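
For instance, would a change along these lines be the intended one? This is only my sketch of the exact-match variant applied here, not code from the repo:

    pred_ids = logits.argmax(-1).masked_fill(~mask, 0).detach().cpu()
    correct = pred_ids == targ
    if logits.dim() == 3:
        # exact match: every non-padding token of the answer must be predicted correctly
        correct = (pred_ids == targ).all(-1)
    acc = correct.float().mean()

    return acc, pred_ids.numpy()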

luludus commented 7 months ago

To align with our evaluation protocol and accurately reproduce our results, please update the code in trainer/losses.py at line 55 as follows:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    correct = correct & mask
    num_non_padding = mask.sum().float().item()

    if 't5' in config.model_class.lower():
        end_mask = targ != 1
        correct = correct & end_mask
        num_non_padding = (mask & end_mask).sum().float().item()
    acc = correct.sum() / num_non_padding

    # ... [other code] ...

Replace it with the following updated evaluation logic:

def multiclass_log_probs(config, pred, targ, shift=False, eps=torch.finfo(torch.float32).eps, **kwargs):
    # ... [other code] ...

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We aim for an exact match across the entire sequence
    acc = correct.float().mean()

    # ... [other code] ...

Please try this to reproduce our results. The way of evaluation will influence the results.

Thanks for your reply! The results now appear to be normal. Are the results of all methods in multimodal model editing calculated with the modified multiclass_log_probs?

tbozhong commented 7 months ago

Yes, the multimodal results of all methods should use the exact match evaluation protocol.

I will update the code as soon as possible.

luludus commented 7 months ago

Yes, the multimodal results of all methods should use the exact match evaluation protocol.

I will update the code as soon as possible.

Does the loss calculation during training also use the modified multiclass_log_probs?

tbozhong commented 7 months ago

The loss calculation has remained unchanged.
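
Roughly speaking (this is only a sketch, not the verbatim losses.py code), the 'nll' that drives training is still the masked token-level log-probability, while only the reported 'acc' switches to exact match:

    # token-level log-probs of the gold tokens, averaged over non-padding positions
    token_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1)
    n_tokens = mask.float().sum()
    log_prob = (token_log_probs * mask.float()).sum() / n_tokens
    nll = -log_prob                         # unchanged: this is what the trainer minimizes

    correct = (pred_ids == targ).all(-1)    # changed: exact match over the whole answer
    acc = correct.float().mean()

    return {"acc": acc, "nll": nll}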

luludus commented 7 months ago

Well, thank you very much

luludus commented 7 months ago

Hi, I trained on BLIP-2 using the MEND method, and with the modified code the result was Reliability: 72.7, Reliability: 48.3.

and my training results.json is:

{
  "results": {
    "loss/edit_val": 0.07809163601696854,
    "loss/image_edit_val": 0.17886119993117608,
    "loss/loc_val": 0.0006802952955840737,
    "edit/acc_val": 0.9795335837006569,
    "edit/log_prob_val": -0.07809163601696854,
    "edit/prob_val": 0.9682461228370667,
    "inner/acc_val": 0.9809760047197342,
    "image_rephrase/acc_val": 0.9527982782125473,
    "time/edit_val": 0.41299159598350527,
    "loc/acc_val": 0.9900922179222107,
    "image_loc/acc_val": 0.8431490063667297,
    "loss/total_val": 0.02818046481552301,
    "loss/total_edit_val": 0.02818046481552301,
    "memory/alloc_max_val": 0.0,
    "memory/res_max_val": 0.0,
    "eval_time/elapsed": 1126.2992458343506,
    "eval_time/average": 1.1262992458343506
  }
}

Is there something wrong with my training?

tbozhong commented 7 months ago

The results look normal to me; can you give more detail about the issue?

luludus commented 7 months ago

The result is different from the paper "Can We Edit Multimodal Large Language Models?". I used the same dataset and called the editor as in multimodal_editor.py:

metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    targets=targets,
    image=image,
    rephrase_prompts=rephrase_prompts,
    rephrase_image=rephrase_image,
    locality_inputs=locality_inputs,
    keep_original_weight=True
)

tbozhong commented 7 months ago

As for MEND and SERAC, please use the trainer, not the editor, to reproduce our results. You can refer to the *acc_val entries in the results.json mentioned above. Sorry for any confusion; I will update the README for reproducing our results soon.
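
For reference, a MEND training run with the trainer looks roughly like this (the yaml and data paths are illustrative and should be adapted to your setup):

from easyeditor import MENDMultimodalTrainingHparams, CaptionDataset, MultimodalTrainer

# illustrative paths; point these at your own hparams file and the E-IC (caption) data
hparams = MENDMultimodalTrainingHparams.from_hparams('hparams/TRAINING/MEND/blip2.yaml')
train_ds = CaptionDataset('data/caption_train_edit.json', config=hparams)
eval_ds = CaptionDataset('data/caption_eval_edit.json', config=hparams)

trainer = MultimodalTrainer(
    config=hparams,
    train_set=train_ds,
    val_set=eval_ds
)
trainer.run()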

zxlzr commented 7 months ago

Sorry for the inconvenience caused by the overly brief description in the main README.md. We will provide an MMEdit.md as soon as possible to help reproduce the results.

luludus commented 7 months ago

As for MEND and SERAC, please use the trainer, not the editor, to reproduce our results. You can refer to the *acc_val entries in the results.json mentioned above. Sorry for any confusion; I will update the README for reproducing our results soon.

Isn't the result in results.json the same as the metrics calculated during training, i.e., computed without the modified multiclass_log_probs?

If the training loss remains unchanged, then multiclass_log_probs cannot change, right? Should we add a new function for evaluation?

tbozhong commented 7 months ago

Hi👋, I'm sorry for taking up your time. If you have further questions, please add my WeChat YouKn0wWho for further discussion.

zxlzr commented 7 months ago

Hi guys,

Due to the incomplete README for the multimodal editing tasks, many users have encountered issues while running the code. We apologize for this inconvenience and will release a README to facilitate code execution. As a new task in the era of LLMs, we have been dedicated to designing reasonable evaluation metrics, which has led to the rapid iteration of EasyEdit. Here, we have two settings: exact match and accuracy (the default setting in EasyEdit). While there are differences in the results, they do not affect the main conclusions about editing multimodal LLMs. We will report the results under both settings ASAP.

Due to the unique nature of multimodal tasks, there are slight differences in the settings; we sincerely apologize for any inconvenience. We will temporarily close this issue. If you have any inquiries or need assistance, please don't hesitate to reach out to us. We are committed to providing ongoing support and maintenance for EasyEdit. If there are any questions, please add the WeChat YouKn0wWho for further discussion.

EasyEdit Team

zxlzr commented 6 months ago

Hi guys, we have reported the Exact Match/Accuracy results for Reliability and T-Generality at https://arxiv.org/abs/2310.08475. If there are any issues, feel free to contact us.