zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License

How to reproduce the multi-modal results using MEND? #186

Closed lliutianc closed 8 months ago

lliutianc commented 8 months ago

Hi there,

Thanks for your great work! I am currently trying to edit the two provided multimodal LLMs (MLLMs) with MEND, but, similar to #177, the resulting rewrite_acc, rephrase_acc, and rephrase_image_acc are all below 0.1.

For your information, I am using the following code:

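from easyeditor import MultimodalTrainer, MultimodalEditor, VQADataset
from easyeditor import MENDMultimodalTrainingHparams, MENDMultimodalHparams

def train_MEND_Blip2OPT_VQA(debug=False):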
    if debug: 
        size = 100
    else:
        size = None

    hparams = MENDMultimodalTrainingHparams.from_hparams('hparams/TRAINING/MEND/blip2_local.yaml')
    train_ds = VQADataset('data/vqa/vqa_train.json', config=hparams, size=size)
    eval_ds = VQADataset('data/vqa/vqa_eval.json', config=hparams, size=size)

    trainer = MultimodalTrainer(
        config=hparams,
        train_set=train_ds,
        val_set=eval_ds
    )

    trainer.run()   

def edit_MEND_Blip2OPT_VQA(debug=False):
    hparams = MENDMultimodalHparams.from_hparams('hparams/MEND/blip2_local.yaml')
    editor = MultimodalEditor.from_hparams(hparams)

    if debug: 
        size = 100
    else:
        size = None

    train_ds = VQADataset('data/vqa/vqa_train.json', config=hparams, size=size)
    eval_ds = VQADataset('data/vqa/vqa_eval.json', config=hparams, size=size)

    metrics, edited_model, _ = editor.edit_dataset(
        ds=eval_ds,
        train_ds=train_ds,
        keep_original_weight=True        
    )

    print_result(metrics)

I reorganized the folder and replaced the pre-trained OPT and Vicuna with versions that can be downloaded directly from HF (Hugging Face):

name: lmsys/vicuna-7b-v1.3
model_name: minigpt4
model_class: Blip2OPT
tokenizer_class: LlamaTokenizer
tokenizer_name: lmsys/vicuna-7b-v1.3

name: lmsys/vicuna-7b-v1.3
model_name: minigpt4
model_class: Blip2OPT
tokenizer_class: LlamaTokenizer
tokenizer_name: lmsys/vicuna-7b-v1.3

Because of these modifications, my hparams files are named XX_local.yaml.

lliutianc commented 8 months ago

Sorry, I pasted the minigpt4 config twice. The config I used for BLIP-2 is:

name: facebook/opt-2.7b
model_name: blip2
model_class: Blip2OPT
tokenizer_class: GPT2Tokenizer
tokenizer_name: facebook/opt-2.7b

tbozhong commented 8 months ago

Sorry, I'm unable to reproduce the issue you're experiencing. Your configuration appears to be correct. Could you provide more detailed information to assist in troubleshooting?

lliutianc commented 8 months ago

Thanks for your reply!

Another thing I forgot to mention: I cannot run edit_dataset, but I can run edit (on the specific sample provided in the edit_xxx functions). This may be because the data is not loaded onto the GPU, so I made a small patch to the __getitem__ method of VQADataset:

    def __getitem__(self, index):
        # Move every tensor in the sample to the configured device before returning it
        # (dict_to is the EasyEdit helper used for this).
        data = self._data[index]
        data = dict_to(data, self.config.device)
        return data

But I don't think this modification explains the difference.
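For reference, the patch only needs every tensor in a (possibly nested) sample to end up on `config.device`. A minimal sketch of such a recursive mover, assuming tensors are the only leaves that expose a `.to(device)` method; `move_to_device` is a hypothetical name, not EasyEdit's `dict_to`, which may treat nesting differently:

```python
from typing import Any


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively move a nested sample to `device` (hypothetical helper).

    Dicts and lists are rebuilt with their contents moved; tuples (including
    namedtuples) come back as plain tuples; any other object with a callable
    `.to` (e.g. torch.Tensor) is moved via `.to(device)`; everything else
    (str, int, None, ...) is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(value, device) for value in obj)
    to = getattr(obj, "to", None)
    if callable(to):
        return to(device)
    return obj
```

Inside the patched method this would read `return move_to_device(self._data[index], self.config.device)`. Moving data in `__getitem__` is usually only safe with `num_workers=0`, because CUDA cannot be re-initialized in forked DataLoader worker processes; with workers, the move normally belongs in the training loop.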

I am using float precision.

My complete environment is as follows:

name: easyedit
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2024.2.2=hbcca054_0
  - comm=0.2.1=pyhd8ed1ab_0
  - debugpy=1.6.7=py39h6a678d5_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - executing=2.0.1=pyhd8ed1ab_0
  - importlib_metadata=7.0.1=hd8ed1ab_0
  - ipykernel=6.29.2=pyhd33586a_0
  - ipython=8.18.1=pyh707e725_3
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.0=pyhd8ed1ab_0
  - jupyter_core=5.7.1=py39hf3d152e_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=13.2.0=h807b86a_5
  - libgomp=13.2.0=h807b86a_5
  - libnsl=2.0.1=hd590300_0
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.45.1=h2797004_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.2.13=hd590300_5
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.2.1=hd590300_0
  - packaging=23.2=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=23.3.1=py39h06a4308_0
  - platformdirs=4.2.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - psutil=5.9.8=py39hd1e30aa_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - python=3.9.18=h0755675_1_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.9=4_cp39
  - pyzmq=25.1.2=py39h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py39h06a4308_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.41.2=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tornado=6.4=py39hd1e30aa_0
  - traitlets=5.14.1=pyhd8ed1ab_0
  - typing_extensions=4.9.0=pyha770c72_0
  - tzdata=2023d=h04d1e81_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.41.2=py39h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - zeromq=4.3.5=h6a678d5_0
  - zipp=3.17.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - pip:
      - accelerate==0.27.2
      - aiohttp==3.9.3
      - aiosignal==1.3.1
      - antlr4-python3-runtime==4.8
      - async-timeout==4.0.3
      - attrs==23.2.0
      - blessed==1.20.0
      - certifi==2024.2.2
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cmake==3.28.3
      - cycler==0.12.1
      - datasets==1.18.3
      - dill==0.3.8
      - einops==0.4.0
      - fairscale==0.4.13
      - filelock==3.13.1
      - fonttools==4.48.1
      - frozenlist==1.4.1
      - fsspec==2024.2.0
      - gpustat==1.1
      - higher==0.2.1
      - huggingface-hub==0.20.3
      - hydra-core==1.1.1
      - idna==3.6
      - importlib-metadata==6.3.0
      - iopath==0.1.10
      - jinja2==3.1.3
      - joblib==1.3.2
      - kiwisolver==1.4.5
      - lit==17.0.6
      - markupsafe==2.1.5
      - matplotlib==3.5.1
      - mpmath==1.3.0
      - multidict==6.0.5
      - multiprocess==0.70.16
      - networkx==3.2.1
      - nltk==3.6.5
      - numpy==1.22.1
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-cupti-cu11==11.7.101
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - nvidia-cufft-cu11==10.9.0.58
      - nvidia-curand-cu11==10.2.10.91
      - nvidia-cusolver-cu11==11.4.0.1
      - nvidia-cusparse-cu11==11.7.4.91
      - nvidia-ml-py==12.535.133
      - nvidia-nccl-cu11==2.14.3
      - nvidia-nvtx-cu11==11.7.91
      - omegaconf==2.1.1
      - openai==0.27.9
      - opencv-python==4.8.0.76
      - pandas==1.4.0
      - peft==0.4.0.dev0
      - pillow==10.2.0
      - portalocker==2.8.2
      - pyarrow==15.0.0
      - pyparsing==3.1.1
      - pytz==2024.1
      - pyyaml==6.0
      - regex==2023.12.25
      - requests==2.31.0
      - safetensors==0.4.2
      - scikit-learn==1.0.2
      - scipy==1.7.3
      - sentence-transformers==2.2.2
      - sentencepiece==0.1.99
      - sympy==1.12
      - threadpoolctl==3.3.0
      - timm==0.9.7
      - tokenizers==0.13.3
      - torch==2.0.1
      - torchvision==0.15.2
      - tqdm==4.62.3
      - transformers==4.30.1
      - triton==2.0.0
      - urllib3==2.2.0
      - xxhash==3.4.1
      - yarl==1.9.4

tbozhong commented 8 months ago

Are you using the latest version of the code? You can use the train_MEND_MiniGPT4_Caption() function in multimodal_edit.py to view the results without edit_dataset.

zxlzr commented 8 months ago

Hi, have you solved your issue yet?

lliutianc commented 8 months ago

Sorry, but I am using the latest version and have not solved it yet 😂

tbozhong commented 8 months ago

How about the results of using train_MEND_MiniGPT4_Caption()?

lliutianc commented 8 months ago

Thanks for asking, but I am still testing the code. ~The data preparation is pretty slow, in fact 😂~

Sorry, but running the code directly on an A6000 with 48 GB runs out of memory... What device did you use?

Update: I have pushed my modifications here.

Could you run the code and see whether the issue occurs? The environment is listed above. The BLIP-2 and MiniGPT-4 experiments are in edit_blip2.py and edit_minigpt4.py, respectively.

tbozhong commented 8 months ago

We conducted our experiments on an A800 with 80 GB. Perhaps you could start with the train_MEND_Blip2OPT_VQA() function. I will also run your code as soon as possible.
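As a rough check on the memory question: in float32 (the precision mentioned earlier in this thread) each parameter takes 4 bytes, so the language model's weights alone come to about 26 GiB for a 7B Vicuna versus about 10 GiB for OPT-2.7B, before anything else is allocated. The sketch below only does that arithmetic; it assumes nominal parameter counts read off the model names and counts weights only, so it is an estimate, not a measurement of EasyEdit's footprint:

```python
def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """GiB (2**30 bytes) needed just to store the weights."""
    return num_params * bytes_per_param / 2**30


def bytes_to_gib(num_bytes: float) -> float:
    """Convert a raw byte count (e.g. memory/alloc_max_train) to GiB."""
    return num_bytes / 2**30


if __name__ == "__main__":
    # Nominal sizes from the model names; the vision encoder, Q-Former,
    # activations, gradients and the MEND hypernetwork are NOT included.
    for name, n_params in [("facebook/opt-2.7b", 2.7e9), ("lmsys/vicuna-7b-v1.3", 7e9)]:
        print(f"{name}: {weight_memory_gib(n_params, 4):.1f} GiB in fp32, "
              f"{weight_memory_gib(n_params, 2):.1f} GiB in fp16")
    # Peak allocation logged as memory/alloc_max_train for the BLIP-2 MEND run in this thread.
    print(f"logged BLIP-2 peak: {bytes_to_gib(20027816181.76):.2f} GiB")
```

For comparison, the BLIP-2 run's logged `memory/alloc_max_train` of 20027816181.76 bytes is about 18.65 GiB, so swapping OPT-2.7B for a 7B model in fp32 adds roughly 16 GiB of weights on top of that.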

zxlzr commented 8 months ago

Hi, have you solved your issue yet?

lliutianc commented 8 months ago

Sorry, I am still testing the new code. I will keep you posted once I figure out the issue. Thanks!

lliutianc commented 8 months ago

Hi,

I am still running train_MEND_Blip2OPT_Caption.

It seems that this function overwrites the output of train_MEND_Blip2OPT_VQA. Nonetheless, the trained editor can still be loaded in edit_MEND_Blip2OPT_VQA. Surprisingly, with this editor the edited model in fact outperforms the unedited one, which is not the case after running train_MEND_Blip2OPT_VQA. My original issue was that I didn't observe any improvement from the edit 😂.

Could you check whether this finding holds in your codebase? If so, are your reported results in fact based on the editor trained with train_MEND_Blip2OPT_Caption?

Thanks!
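If both training functions save the trained editor to the same location, training the Caption editor will silently replace the VQA one. One way to rule that out is to derive a separate output directory per (model, task) pair and refuse to reuse a non-empty one. A sketch under that assumption; `run_output_dir` and its directory layout are hypothetical, and how the resulting path is fed back to EasyEdit (e.g. through `archive` when loading) is left open:

```python
from pathlib import Path


def run_output_dir(base: str, model: str, task: str, overwrite: bool = False) -> Path:
    """Return base/model/task, creating it if needed (hypothetical helper).

    Raises FileExistsError if the directory already holds files and
    overwrite is False, so one task's run cannot clobber another's.
    """
    out = Path(base) / model / task
    if out.exists() and any(out.iterdir()) and not overwrite:
        raise FileExistsError(f"{out} already contains files; pass overwrite=True to reuse it")
    out.mkdir(parents=True, exist_ok=True)
    return out
```

Usage would look like `run_output_dir("results", "blip2", "vqa")` for one run and `run_output_dir("results", "blip2", "caption")` for the other, giving two distinct directories.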

tbozhong commented 8 months ago

I am currently trying to replicate your issue and would like to see your log from running train_MEND_Blip2OPT_VQA(). Interestingly, when I run this function in debug mode with size 20, I observe outstanding performance.

lliutianc commented 8 months ago

Hi, I tried size=20; my results are shown below:

03/05/2024 22:33:58 - INFO - easyeditor.trainer.BaseTrainer -   Step 100:
03/05/2024 22:33:58 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  32.98687; loss/image_edit_train:  33.42197; loss/loc_train:  0.24394; edit/acc_train:  0.11000; edit/log_prob_train: -32.98687; edit/prob_train:  0.10142; inner/acc_train:  0.11000; image_rephrase/acc_train:  0.11000; time/edit_train:  0.45872; loc/acc_train:  0.81520; image_loc/acc_train:  0.31451; loss/total_train:  7.19560; loss/total_edit_train:  7.19560; memory/alloc_max_train:  20027816181.76000; memory/res_max_train:  21994196172.80000; grad_train:  23482.68172; lr/lr0_train:  0.00010; lr/lr1_train:  0.00008; lr/lr2_train:  0.00009; lr/lr3_train:  0.00008; lr/lr4_train:  0.00008; lr/lr5_train:  0.00008
03/05/2024 22:35:59 - INFO - easyeditor.trainer.BaseTrainer -   Step 200:
03/05/2024 22:35:59 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  1.36776; loss/image_edit_train:  1.28686; loss/loc_train:  0.11442; edit/acc_train:  0.88000; edit/log_prob_train: -1.36776; edit/prob_train:  0.80358; inner/acc_train:  0.84000; image_rephrase/acc_train:  0.86000; time/edit_train:  0.41939; loc/acc_train:  0.91114; image_loc/acc_train:  0.46590; loss/total_train:  0.51579; loss/total_edit_train:  0.51579; memory/alloc_max_train:  20138626048.00000; memory/res_max_train:  22053650432.00000; grad_train:  14272.82384; lr/lr0_train:  0.00009; lr/lr1_train:  0.00003; lr/lr2_train:  0.00007; lr/lr3_train:  0.00004; lr/lr4_train:  0.00006; lr/lr5_train:  0.00005

03/05/2024 22:38:01 - INFO - easyeditor.trainer.BaseTrainer -   Step 300:
03/05/2024 22:38:01 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  0.02751; loss/image_edit_train:  0.02728; loss/loc_train:  0.01199; edit/acc_train:  1.00000; edit/log_prob_train: -0.02751; edit/prob_train:  0.97755; inner/acc_train:  1.00000; image_rephrase/acc_train:  1.00000; time/edit_train:  0.42494; loc/acc_train:  0.96933; image_loc/acc_train:  0.63505; loss/total_train:  0.03511; loss/total_edit_train:  0.03511; memory/alloc_max_train:  20138626048.00000; memory/res_max_train:  22053650432.00000; grad_train:  2852.22145; lr/lr0_train:  0.00007; lr/lr1_train:  0.00001; lr/lr2_train:  0.00005; lr/lr3_train:  0.00002; lr/lr4_train:  0.00005; lr/lr5_train:  0.00003

My loss is also decreasing, but it stays larger than yours.

Update: my loss on the full dataset is as follows.

03/06/2024 00:17:50 - INFO - easyeditor.trainer.BaseTrainer -   Step 100:
03/06/2024 00:17:50 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  48.06095; loss/image_edit_train:  47.90444; loss/loc_train:  0.61489; edit/acc_train:  0.00000; edit/log_prob_train: -48.06095; edit/prob_train:  0.00000; inner/acc_train:  0.00000; image_rephrase/acc_train:  0.00000; time/edit_train:  0.52378; loc/acc_train:  0.76714; image_loc/acc_train:  0.33986; loss/total_train:  10.86272; loss/total_edit_train:  10.86272; memory/alloc_max_train:  20080481100.80000; memory/res_max_train:  22131224084.48000; grad_train:  29022.18234; lr/lr0_train:  0.00010; lr/lr1_train:  0.00008; lr/lr2_train:  0.00008; lr/lr3_train:  0.00008; lr/lr4_train:  0.00008; lr/lr5_train:  0.00008
03/06/2024 00:20:18 - INFO - easyeditor.trainer.BaseTrainer -   Step 200:
03/06/2024 00:20:18 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  11.21624; loss/image_edit_train:  10.60467; loss/loc_train:  0.16662; edit/acc_train:  0.30000; edit/log_prob_train: -11.21624; edit/prob_train:  0.25262; inner/acc_train:  0.31000; image_rephrase/acc_train:  0.31000; time/edit_train:  0.51375; loc/acc_train:  0.80211; image_loc/acc_train:  0.28828; loss/total_train:  2.59925; loss/total_edit_train:  2.59925; memory/alloc_max_train:  20137420800.00000; memory/res_max_train:  22223519744.00000; grad_train:  27337.36012; lr/lr0_train:  0.00008; lr/lr1_train:  0.00004; lr/lr2_train:  0.00005; lr/lr3_train:  0.00003; lr/lr4_train:  0.00005; lr/lr5_train:  0.00004
03/06/2024 00:22:48 - INFO - easyeditor.trainer.BaseTrainer -   Step 300:
03/06/2024 00:22:48 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_train:  4.21999; loss/image_edit_train:  4.12660; loss/loc_train:  0.10113; edit/acc_train:  0.54000; edit/log_prob_train: -4.21999; edit/prob_train:  0.42352; inner/acc_train:  0.55000; image_rephrase/acc_train:  0.53000; time/edit_train:  0.51983; loc/acc_train:  0.87078; image_loc/acc_train:  0.38218; loss/total_train:  1.17352; loss/total_edit_train:  1.17352; memory/alloc_max_train:  20137420800.00000; memory/res_max_train:  22223519744.00000; grad_train:  41533.94309; lr/lr0_train:  0.00008; lr/lr1_train:  0.00002; lr/lr2_train:  0.00004; lr/lr3_train:  0.00001; lr/lr4_train:  0.00004; lr/lr5_train:  0.00002

Could you compare this with yours? Does this result look reasonable?
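To compare runs like these side by side, it can help to pull the metrics out of the BaseTrainer lines instead of reading them by eye. A small sketch, assuming the `name: value;` layout shown in these logs; the regex is fitted to these lines and is not an EasyEdit utility (the multi-line validation summary, with a space before each colon, is out of scope):

```python
import re
from typing import Dict, Iterable, List

# One "name: value" pair as printed by BaseTrainer, e.g. "edit/acc_train:  0.11000"
# or "edit/log_prob_train: -32.98687". Timestamps such as "22:33:58" and
# "Step 100:" do not match, since no spaces and a number follow their colons.
_PAIR = re.compile(r"([\w/]+):[ \t]+(-?\d+(?:\.\d+)?)")


def parse_metrics(line: str) -> Dict[str, float]:
    """Extract every metric on one log line into a {name: value} dict."""
    return {name: float(value) for name, value in _PAIR.findall(line)}


def trajectory(lines: Iterable[str], name: str) -> List[float]:
    """Values of one metric across the lines that report it, in order."""
    return [m[name] for m in map(parse_metrics, lines) if name in m]
```

Read this way, `edit/acc_train` goes 0.0, 0.30, 0.54 at steps 100, 200 and 300 on the full dataset, versus 0.11, 0.88, 1.00 for the size=20 run.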

yaohui120 commented 8 months ago

In my opinion, edit_dataset() is meant for testing the IKE method. When testing trainable methods, you can try something like this:

from easyeditor import MultimodalTrainer, VQADataset, SERACMultimodalTrainingHparams

def test_SERAC_MiniGPT4_VQA():
    hparams = SERACMultimodalTrainingHparams.from_hparams('hparams/TRAINING/SERAC/minigpt4.yaml')

    eval_ds = VQADataset('./data/MMEdit/vqa_eval.json', config=hparams)
    trainer = MultimodalTrainer(
        config=hparams,
        train_set=eval_ds,
        val_set=eval_ds
    )

    val_steps = len(eval_ds._data)
    val_info = trainer.validate(log=True)
    trainer.echo(val_steps, val_info, pretty=True)

Remember to set archive in the .yaml file to the path of the saved model weights.
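Setting `archive` can also be scripted. A minimal sketch that rewrites the key in the raw .yaml text, assuming `archive` is a top-level key written on a single line; `set_archive` is a hypothetical helper:

```python
import re


def set_archive(yaml_text: str, path: str) -> str:
    """Point the top-level `archive:` key at `path`, appending the key if missing."""
    pattern = re.compile(r"^archive:.*$", flags=re.MULTILINE)
    line = f"archive: {path}"
    if pattern.search(yaml_text):
        # A function replacement avoids backslashes in `path` being read as escapes.
        return pattern.sub(lambda _match: line, yaml_text, count=1)
    sep = "" if not yaml_text or yaml_text.endswith("\n") else "\n"
    return f"{yaml_text}{sep}{line}\n"
```

For example, `set_archive(text, "path/to/saved_editor")` (a placeholder path) turns an `archive: null` line into `archive: path/to/saved_editor`. Editing the text directly keeps comments intact, which a `yaml.safe_load`/`yaml.safe_dump` round trip would drop.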

lliutianc commented 8 months ago

Hi, thanks for the suggestion. But I don't think edit_dataset() is tied to IKE. It is quite similar to edit(), except that in the latter the provided edit samples are wrapped into requests, which are then handled the same way as the ds in edit_dataset(). train_ds is only used when the editor is IKE.
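The wrapping described here is essentially a columns-to-rows conversion: edit() takes parallel lists and turns each index into one request dict, the same shape a dataset item has. A sketch of that idea only; the request keys and the nested `{modality: {field: [...]}}` layout assumed for `locality_inputs` are illustrative, not EasyEdit's actual internals:

```python
from typing import Any, Dict, List, Optional


def build_requests(prompts: List[str], target_new: List[str], image: List[Any],
                   locality_inputs: Optional[Dict[str, Dict[str, List[Any]]]] = None,
                   ) -> List[Dict[str, Any]]:
    """Turn parallel argument lists into one request dict per edit (hypothetical keys)."""
    n = len(prompts)
    if len(target_new) != n or len(image) != n:
        raise ValueError("prompts, target_new and image must have the same length")
    requests = [{"prompt": p, "target": t, "image": img}
                for p, t, img in zip(prompts, target_new, image)]
    # Assumed layout: {"text": {"prompt": [...], ...}, "vision": {...}};
    # each inner list is split across the requests by position.
    for modality, fields in (locality_inputs or {}).items():
        for field, values in fields.items():
            if len(values) != n:
                raise ValueError(f"locality_inputs[{modality!r}][{field!r}] needs {n} items")
            for request, value in zip(requests, values):
                request[f"locality_{modality}_{field}"] = value
    return requests
```

Under this reading, each returned dict plays the role of one `ds` item in edit_dataset().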

Below is the official demonstration of a single knowledge edit:

metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    target_new=target_new,
    image=image,
    locality_inputs=locality_inputs,
    keep_original_weight=False
)
## metrics: edit success, rephrase success, locality e.g.
## edited_model: post-edit model

IMO, separating editing from training is more reasonable: once the hypernet, which infers the desired update to the model from the current gradient and input data, is trained, MEND lets one edit the model with new knowledge at inference time.
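Why this works at inference time: for one example, the gradient of a linear layer y = Wx is the rank-1 outer product δxᵀ, where δ = ∂L/∂y. MEND's hypernetwork only transforms the two factors (x, δ) into (x̃, δ̃) and applies W ← W − α δ̃x̃ᵀ, so no weight-sized gradient has to be processed. A toy pure-Python sketch of that update rule, where `editor` is a hypothetical callable standing in for the trained hypernetwork (MEND itself learns that transform):

```python
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]
# An "editor" maps the raw factors (x, delta) to transformed ones (x_tilde, delta_tilde).
Editor = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def outer(delta: Vector, x: Vector) -> Matrix:
    """delta x^T: the single-example gradient dL/dW of a linear layer y = W x."""
    return [[d * xi for xi in x] for d in delta]


def rank1_edit(W: Matrix, x: Vector, delta: Vector, editor: Editor, lr: float) -> Matrix:
    """Return a new W - lr * outer(delta_tilde, x_tilde); W itself is not modified."""
    x_t, d_t = editor(x, delta)
    step = outer(d_t, x_t)
    return [[w - lr * s for w, s in zip(w_row, s_row)] for w_row, s_row in zip(W, step)]


def identity_editor(x: Vector, delta: Vector) -> Tuple[Vector, Vector]:
    """No-op editor: reduces rank1_edit to an ordinary gradient step."""
    return x, delta
```

With `identity_editor` this is plain gradient descent on one example; the learned transform is what lets a single step generalize to rephrasings while leaving unrelated inputs (the locality metrics) intact.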

lliutianc commented 8 months ago

Hi, I've tested the overall pipeline of training MEND for BLIP-2 on eval_ds, which is smaller. Below is the output from the final steps:

2024-03-08 13:41:22,315 - INFO - Step 1900/2093            outer_acc: 1.00000      image_acc: 1.00000      inner_acc: 0.99947      it_time: 1.4234 loc_acc: 0.99776     , image_loc: 0.97104      (MultimodalTrainer.py:221)
2024-03-08 13:43:50,086 - INFO - Step 2000/2093            outer_acc: 1.00000      image_acc: 1.00000      inner_acc: 0.99950      it_time: 1.4261 loc_acc: 0.99783     , image_loc: 0.97107      (MultimodalTrainer.py:221)
2024-03-08 13:46:04,380 - INFO - Step 2093/2093            outer_acc: 0.99952      image_acc: 0.99952      inner_acc: 0.99904      it_time: 1.4269 loc_acc: 0.99793     , image_loc: 0.97098      (MultimodalTrainer.py:221)
03/08/2024 13:46:04 - INFO - easyeditor.trainer.BaseTrainer -   Step 50000:
03/08/2024 13:46:04 - INFO - easyeditor.trainer.BaseTrainer -   loss/edit_val       :  0.00365
loss/image_edit_val :  0.00245
loss/loc_val        :  0.00283
edit/acc_val        :  0.99952
edit/log_prob_val   : -0.00365
edit/prob_val       :  0.99929
inner/acc_val       :  0.99904
image_rephrase/acc_val:  0.99952
time/edit_val       :  0.52041
loc/acc_val         :  0.99793
image_loc/acc_val   :  0.97098
loss/total_val      :  0.00368
loss/total_edit_val :  0.00368
memory/alloc_max_val:  21567038921.69326
memory/res_max_val  :  23735566336.00000
eval_time/elapsed   :  2986.53649
eval_time/average   :  1.42692

These logs look good. However, when I tested the trained MEND editor with edit_dataset on eval_ds again, the output (post-edit metrics first, then pre-edit metrics after =====) was as follows:

rewrite_acc: 0.08818738309179718
rephrase_acc: 0.08638830229504736
rephrase_image_acc: 0.07346370083631878
locality_acc: 0.9977777687991087
multimodal_locality_acc: 0.9983516483516484
=====
rewrite_acc: 0.08282237197728574
rephrase_acc: 0.08270349501281962
rephrase_image_acc: 0.06986364317151046

My complete code is as follows:

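from statistics import mean

from easyeditor import MultimodalTrainer, MultimodalEditor, VQADataset
from easyeditor import MENDMultimodalTrainingHparams, MENDMultimodalHparams

def print_result(metrics):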
    rewrite_acc = mean([m['post']['rewrite_acc'].item() for m in metrics])
    rephrase_acc = mean([m['post']['rephrase_acc'].item() for m in metrics])
    rephrase_image_acc = mean([m['post']['rephrase_image_acc'].item() for m in metrics])
    locality_acc = mean([m['post']['locality_acc'].item() for m in metrics])
    locality_image_acc = mean([m['post']['multimodal_locality_acc'].item() for m in metrics])

    print(f'rewrite_acc: {rewrite_acc}')
    print(f'rephrase_acc: {rephrase_acc}')
    print(f'rephrase_image_acc: {rephrase_image_acc}')
    print(f'locality_acc: {locality_acc}')
    print(f'multimodal_locality_acc: {locality_image_acc}')

def print_result_pre(metrics):
    rewrite_acc = mean([m['pre']['rewrite_acc'].item() for m in metrics])
    rephrase_acc = mean([m['pre']['rephrase_acc'].item() for m in metrics])
    rephrase_image_acc = mean([m['pre']['rephrase_image_acc'].item() for m in metrics])

    print(f'rewrite_acc: {rewrite_acc}')
    print(f'rephrase_acc: {rephrase_acc}')
    print(f'rephrase_image_acc: {rephrase_image_acc}')

def train_MEND_Blip2OPT_VQA(size=None):

    hparams = MENDMultimodalTrainingHparams.from_hparams('hparams/TRAINING/MEND/blip2_local.yaml')

    train_ds = VQADataset('data/vqa/vqa_train.json', config=hparams, size=size)
    eval_ds = VQADataset('data/vqa/vqa_eval.json', config=hparams, size=size)

    trainer = MultimodalTrainer(
        config=hparams,
        train_set=eval_ds,
        val_set=eval_ds,
    )

    trainer.run()   

def edit_MEND_Blip2OPT_VQA(size=None):
    hparams = MENDMultimodalHparams.from_hparams('hparams/MEND/blip2_local.yaml')

    train_ds = VQADataset('data/vqa/vqa_train.json', config=hparams, size=size)
    eval_ds = VQADataset('data/vqa/vqa_eval.json', config=hparams, size=size)

    editor = MultimodalEditor.from_hparams(hparams)
    metrics, edited_model, _ = editor.edit_dataset(
        ds=eval_ds,
        train_ds=train_ds,
        keep_original_weight=True        
    )

    print_result(metrics)
    print("=====")
    print_result_pre(metrics)

if __name__ == "__main__":

    size = None

    train_MEND_Blip2OPT_VQA()
    edit_MEND_Blip2OPT_VQA(size=size)
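Because every entry of `metrics` carries both 'pre' and 'post' results (that is how print_result and print_result_pre read them), the effect of the edit can be reported as a difference rather than two separate blocks. A small sketch; `edit_gain` is a hypothetical helper that uses `float()` so it accepts plain numbers and single-element tensors alike:

```python
from statistics import mean
from typing import Dict, Iterable, List, Mapping


def edit_gain(metrics: List[Mapping[str, Mapping[str, object]]],
              keys: Iterable[str] = ("rewrite_acc", "rephrase_acc", "rephrase_image_acc"),
              ) -> Dict[str, float]:
    """Mean post-edit value minus mean pre-edit value, per metric key."""
    gains = {}
    for key in keys:
        pre = mean(float(m["pre"][key]) for m in metrics)
        post = mean(float(m["post"][key]) for m in metrics)
        gains[key] = post - pre
    return gains
```

With the numbers above, the rewrite_acc gain is 0.0882 minus 0.0828, about 0.005, whereas the trainer's validation reported edit/acc_val of 0.99952 on the same eval_ds; a gap that large suggests the two evaluation paths differ rather than the editor failing to train.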

tbozhong commented 8 months ago

If you have any more questions or need further assistance, I can reach out to you on WeChat using the provided username YouKn0wWho for convenient communication.