berenslab / retinal_image_counterfactuals

Realistic retinal fundus and OCT counterfactuals using diffusion models and classifiers.

Training of Robust Classifier and Diffusion Models #1

Open gabshi59 opened 4 weeks ago

gabshi59 commented 4 weeks ago

Hi Indu,

Thank you for your wonderful work! This work is very interesting to me and I think the results are amazing. However, I got confused when I tried to apply the method to my own dataset: I do not know how to fine-tune the diffusion models on my data, and I am also a little unsure about how to train the robust classifier. I would greatly appreciate it if you could provide more information on both. Thank you so much!

induilanchezian commented 5 days ago

Hi gabshi59,

Thank you for showing interest in our work and reaching out!

For training the diffusion models, we used OpenAI's guided-diffusion repository: https://github.com/openai/guided-diffusion. You may have to find the parameters that work best for your dataset; we used the parameters specified for the 256x256 unconditional case with a linear noise schedule, i.e.:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"

The command for training the models can be found in the improved-diffusion repository (https://github.com/openai/improved-diffusion), whose training scripts guided-diffusion builds on:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS
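If you later need to instantiate the trained model in Python (e.g. for counterfactual generation), the flags above map one-to-one onto guided-diffusion's script_util helpers. A minimal sketch, assuming a checkpoint saved by the training script (the checkpoint path is illustrative):

import torch
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
)

# Reproduce MODEL_FLAGS from the training command as Python kwargs.
options = model_and_diffusion_defaults()
options.update(
    attention_resolutions="32,16,8",
    class_cond=False,
    diffusion_steps=1000,
    image_size=256,
    learn_sigma=True,
    noise_schedule="linear",
    num_channels=256,
    num_head_channels=64,
    num_res_blocks=2,
    resblock_updown=True,
    use_fp16=True,
    use_scale_shift_norm=True,
)
model, diffusion = create_model_and_diffusion(**options)

# Load the checkpoint written by image_train.py (path is illustrative).
model.load_state_dict(torch.load("path/to/model.pt", map_location="cpu"))
if options["use_fp16"]:
    model.convert_to_fp16()
model.eval()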

For adversarially robust classifier training, we used the TRADES algorithm. The code for training TRADES models is in this repository under counterfactual_utils/train_types/TRADES_training.py. Note that training a robust classifier involves generating adversarial examples within the training loop. The type of adversarial attack can be set via the attack_config parameter of the TRADESTraining object; an attack_config can be created with the create_attack_config function in counterfactual_utils/train_types/helpers.py. A rough sketch of how the pieces fit together follows below.
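Note that the parameter names and constructor signatures in the sketch below are assumed for illustration, not taken verbatim from the repository; check the actual definitions in counterfactual_utils/train_types/ first.

# Hypothetical sketch: the signatures below are assumed for illustration
# and may not match the repository exactly; check the definitions in
# counterfactual_utils/train_types/ before use.
from counterfactual_utils.train_types.helpers import create_attack_config
from counterfactual_utils.train_types.TRADES_training import TRADESTraining

# Configure the inner adversarial attack, e.g. an L2 PGD attack
# (eps, steps and stepsize values are illustrative).
attack_config = create_attack_config(eps=0.25, steps=10, stepsize=0.1, norm="l2")

# The trainer generates adversarial examples with this attack inside
# its training loop; remaining constructor arguments (optimizer,
# schedule, etc.) depend on the actual TRADESTraining signature.
trainer = TRADESTraining(
    model,                        # classifier to be trained
    attack_config=attack_config,
)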

Hope that this clarifies the training procedures :)

gabshi59 commented 2 days ago

Thank you so much for your reply! This is very helpful to us. Thanks again for your great work; we will definitely cite your paper if we apply it as a baseline.