OPTML-Group / Unlearn-Saliency

[ICLR24 (Spotlight)] "SalUn: Empowering Machine Unlearning via Gradient-based Weight Saliency in Both Image Classification and Generation" by Chongyu Fan*, Jiancheng Liu*, Yihua Zhang, Eric Wong, Dennis Wei, Sijia Liu
https://www.optml-group.com/posts/salun_iclr24
MIT License

Cannot reproduce the results of forgetting Imagenette using Stable diffusion #20

Closed Minjong-Lee closed 6 days ago

Minjong-Lee commented 2 weeks ago

Hi, thanks for your great work. I'm trying to reproduce the results of forgetting Imagenette using SD v1.4. I followed the reproduction instructions you provided exactly, and the scripts ran without any errors, so I don't think anything went wrong during training or evaluation. However, I couldn't reproduce the results reported in the paper: some classes seemed to work, but most of them did not unlearn properly.

Could I possibly get some advice on this issue? I suspect my experimental environment (e.g., PyTorch and CUDA versions) might differ from yours, and I'm wondering whether that could be the cause. Thanks a lot!

a-F1 commented 2 weeks ago

Thank you for your appreciation and recognition of our work. Could you kindly provide more detailed information, such as the commands you used?

Minjong-Lee commented 2 weeks ago

Thank you for your reply!

As I said, I used the same commands you provided on this page.

For class 0:

python train-scripts/generate_mask.py --ckpt_path 'models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt' --classes '0' --device '0'
python train-scripts/random_label.py --train_method full --alpha 0.5 --lr 1e-5 --epochs 5 --class_to_forget '0' --mask_path 'mask/0/with_0.5.pt' --device '0'

I didn't make any changes to the code. Since the execution environment (CUDA, Torch, Diffusers, etc.) could be an issue, I've attached the environment I used (CUDA 12.1):

absl-py 2.1.0
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
albumentations 0.4.3
altair 5.4.1
antlr4-python3-runtime 4.8
async-timeout 4.0.3
attrs 24.2.0
blinker 1.8.2
Brotli 1.0.9
cachetools 5.5.0
certifi 2024.7.4
charset-normalizer 3.3.2
click 8.1.7
clip 1.0 /home1/mjlee42/Unlearn-Saliency/SD/src/clip
contourpy 1.1.1
cycler 0.12.1
datasets 2.21.0
diffusers 0.30.1
dill 0.3.8
einops 0.3.0
filelock 3.15.4
fonttools 4.53.1
frozenlist 1.4.1
fsspec 2024.6.1
ftfy 6.2.3
future 1.0.0
gitdb 4.0.11
GitPython 3.1.43
google-auth 2.34.0
google-auth-oauthlib 1.0.0
grpcio 1.66.0
huggingface-hub 0.24.6
idna 3.7
imageio 2.9.0
imageio-ffmpeg 0.4.2
imgaug 0.2.6
importlib-metadata 8.4.0
importlib-resources 6.4.4
invisible-watermark 0.2.0
jinja2 3.1.4
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
kornia 0.6.0
latent-diffusion 0.0.1 /home1/mjlee42/Unlearn-Saliency/SD
lazy-loader 0.4
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.7.5
mdurl 0.1.2
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
multidict 6.0.5
multiprocess 0.70.16
narwhals 1.5.5
networkx 3.1
numpy 1.24.4
oauthlib 3.2.2
omegaconf 2.1.1
opencv-python 4.1.2.30
opencv-python-headless 4.10.0.84
packaging 24.1
pandas 2.0.3
pillow 10.4.0
pip 20.3.3
pkgutil-resolve-name 1.3.10
protobuf 5.27.3
pudb 2019.2
pyarrow 17.0.0
pyasn1 0.6.0
pyasn1-modules 0.4.0
pydeck 0.9.1
pyDeprecate 0.3.1
pygments 2.18.0
pyparsing 3.1.4
PySocks 1.7.1
python-dateutil 2.9.0.post0
pytorch-lightning 1.4.2
pytz 2024.1
PyWavelets 1.4.1
PyYAML 6.0.2
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
requests-oauthlib 2.0.0
rich 13.8.0
rpds-py 0.20.0
rsa 4.9
safetensors 0.4.4
scikit-image 0.20.0
scipy 1.9.1
setuptools 72.1.0
six 1.16.0
smmap 5.0.1
streamlit 1.37.1
taming-transformers 0.0.1 /home1/mjlee42/Unlearn-Saliency/SD/src/taming-transformers
tenacity 8.5.0
tensorboard 2.14.0
tensorboard-data-server 0.7.2
test-tube 0.7.5
tifffile 2023.7.10
tokenizers 0.19.1
toml 0.10.2
torch 1.11.0
torch-fidelity 0.3.0
torchmetrics 0.6.0
torchvision 0.12.0
tornado 6.4.1
tqdm 4.66.5
transformers 4.44.2
typing-extensions 4.11.0
tzdata 2024.1
urllib3 2.2.2
urwid 2.6.15
watchdog 4.0.2
wcwidth 0.2.13
werkzeug 3.0.4
wheel 0.43.0
xxhash 3.5.0
yarl 1.9.4
zipp 3.20.1

I'll be waiting for your response. Thank you again!

anseryuer commented 1 week ago

torchmetrics 0.6.0

I believe the repo pins an outdated version of torchmetrics. With torchmetrics 0.6.0, the metrics all come out inflated, and the FID stays around 4.0+ no matter how good your model is. You can update to the newest torchmetrics (this will affect training, so you'd better create a separate environment just for computing the metrics). The newer versions also add a parameter to the FID metric, `normalize` if I remember correctly, and setting it brings the FID values back to normal. I don't have my code with me right now, but you can test it out. Besides, the authors really need to fix the FID metric problem here.
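Something like this sketch is what I mean (written from memory, so treat the exact argument name and version boundary as assumptions; the batch shapes and data are just placeholders):

# Minimal sketch, assuming torchmetrics >= 0.11 with torch-fidelity installed
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# normalize=True tells torchmetrics the inputs are floats in [0, 1];
# without it, the metric expects uint8 tensors in [0, 255].
fid = FrechetInceptionDistance(feature=2048, normalize=True)

real_images = torch.rand(8, 3, 299, 299)  # stand-in for real images
fake_images = torch.rand(8, 3, 299, 299)  # stand-in for generated images

fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(fid.compute())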

Minjong-Lee commented 1 week ago

Thank you for your reply!

Unfortunately, the problem still hasn't been resolved. My point is that even when following the provided instructions and the hyperparameter settings from the paper, I was unable to reproduce the results.

For example, after running the following commands:

python train-scripts/generate_mask.py --ckpt_path 'models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt' --classes '0' --device '0'
python train-scripts/random_label.py --train_method full --alpha 0.5 --lr 1e-5 --epochs 5 --class_to_forget '0' --mask_path 'mask/0/with_0.5.pt' --device '0'

when I tried to generate a tench, the expected result according to the paper is an image of something other than a tench. However, the actual generated image is attached below (0_0).

The evaluation command is python eval-scripts/generate-images.py --prompts_path 'prompts/imagenette.csv' --save_path 'evaluation_folder/' --model_name {model} --device 'cuda:0'. This issue occurs across multiple classes, so I would like to ask whether anyone has run into it before and how it can be resolved. As for the torchmetrics version: since the generation process does not involve FID, I don't believe the torchmetrics version is relevant (and my own testing confirmed it has no impact).

I have also tried adjusting various hyperparameters, but that did not resolve the issue either. If it is a hyperparameter problem, I would appreciate suggestions for appropriate values; if it is an environment problem, I would be grateful if you could share the library versions used in your actual experiments.

Thank you for considering my issue!

a-F1 commented 1 week ago

Thank you for sharing the commands with us. Based on the images you provided, it seems that the unlearned model may not have been loaded correctly. Could you please check the path of the unlearned model and review the model loading process?
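For example, a quick sanity check (a minimal sketch; the model name below is illustrative, and the path template mirrors the snippet from eval-scripts/generate-images.py quoted later in this thread) is to print the path the script will try to load and verify it exists:

import os

# Example model name; substitute the one you pass via --model_name.
model_name = "compvis-cl-mask-class_0-method_full-alpha_0.5-epoch_5-lr_1e-05"

# Same path template as in eval-scripts/generate-images.py.
model_path = f'models/{model_name}/{model_name.replace("compvis", "diffusers")}.pt'
print(model_path, "->", "found" if os.path.exists(model_path) else "missing")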

Minjong-Lee commented 6 days ago

Thank you for your reply!

As you mentioned, the problem was that I didn't specify the unlearned model path correctly during evaluation. The command I used was

python eval-scripts/generate-images.py \
    --prompts_path 'prompts/imagenette.csv' \
    --save_path 'evaluation_folder/' \
    --model_name 'SD/model/compvis-cl-mask-class_0-method_full-alpha_0.5-epoch_5-lr_1e-05' \
    --device 'cuda:0'

In your code (eval-scripts/generate-images.py), the unlearned model is loaded as follows:

if "SD" not in model_name:
    try:
        model_path = (
            f'models/{model_name}/{model_name.replace("compvis","diffusers")}.pt'
        )

        # model_path = model_name
        unet.load_state_dict(torch.load(model_path))
    except Exception as e:
        print(
            f"Model path is not valid, please check the file name and structure: {e}"
        )

Since 'SD' appeared in the model name, the script never even entered the if statement, so it didn't throw any errors. (Looking back at the path I used, I clearly set it in a very inappropriate way...) After changing 'SD/model/compvis-cl-mask-class_0-method_full-alpha_0.5-epoch_5-lr_1e-05' to 'compvis-cl-mask-class_0-method_full-alpha_0.5-epoch_5-lr_1e-05', it worked properly.
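For reference, the corrected invocation that worked for me is:

python eval-scripts/generate-images.py \
    --prompts_path 'prompts/imagenette.csv' \
    --save_path 'evaluation_folder/' \
    --model_name 'compvis-cl-mask-class_0-method_full-alpha_0.5-epoch_5-lr_1e-05' \
    --device 'cuda:0'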

I kept assuming the training process was at fault and never thought to check the evaluation code. This was clearly my mistake. I really appreciate your kind help!