ys-zong / MEDFAIR

[ICLR 2023 spotlight] MEDFAIR: Benchmarking Fairness for Medical Imaging
https://ys-zong.github.io/MEDFAIR/

How to perform preprocessing for PAPILA and run the training and testing? #3

Closed · pearlmary closed this issue 1 year ago

ys-zong commented 1 year ago

Hi,

Thanks for your interest. You can follow a preprocessing process similar to the one in HAM10000-example.ipynb. All the necessary details can be found in Appendix B.1.2 of the paper. I also aim to update the preprocessing code for all the datasets soon.
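
For anyone landing here, this is roughly the kind of pipeline the HAM10000 notebook walks through. A minimal sketch only: the paths, the 256x256 size, and the column names below are illustrative assumptions, not MEDFAIR's exact format; follow HAM10000-example.ipynb and Appendix B.1.2 for the authoritative details.

```python
import os
import pickle

import numpy as np
import pandas as pd
from PIL import Image

# Hypothetical metadata file and column names - adjust to your dataset.
meta = pd.read_csv('data/PAPILA/metadata.csv')

# Resize all images to a fixed size (256x256 is an assumption) and stack
# them into one array for the dataloader to index into.
images = []
for fname in meta['image_path']:
    img = Image.open(os.path.join('data/PAPILA/images', fname)).convert('RGB')
    images.append(np.asarray(img.resize((256, 256))))

with open('data/PAPILA/images.pkl', 'wb') as f:
    pickle.dump(np.stack(images), f)

# Split the metadata (with label and sensitive-attribute columns) into
# train/val/test csv files, e.g. an 80/10/10 split.
train = meta.sample(frac=0.8, random_state=0)
rest = meta.drop(train.index)
val, test = rest.iloc[:len(rest) // 2], rest.iloc[len(rest) // 2:]
for name, split in [('train', train), ('val', val), ('test', test)]:
    split.to_csv(f'data/PAPILA/{name}.csv', index=False)
```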

For training and testing, once you finish the preprocessing you can simply pass --dataset_name PAPILA, along with the other arguments, to python main.py.
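
For example, with the baseline setting (the flags simply mirror the ones used later in this thread, not recommended hyperparameters):

```bash
python main.py --experiment baseline --experiment_name test \
    --dataset_name PAPILA --backbone cusResNet18 --total_epochs 10 \
    --sensitive_name Age --batch_size 1024 --lr 0.01 --sens_classes 2 \
    --val_strategy loss --output_dim 1 --num_classes 1
```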

Feel free to reopen it should you have any problems.

pearlmary commented 1 year ago

Hi Zong. Thank you for the quick reply. I've completed preprocessing for the PAPILA dataset and fixed a few bugs in some of the debiasing algorithms, but I'm stuck on errors from three of them: LAFTR, SWA, and CFair. I need help resolving these errors.

LAFTR:

```
(fairmed) root@537614b35cbf:/workspace/MEDFAIR# python main.py --experiment LAFTR --experiment_name test --dataset_name PAPILA --backbone cusResNet18 --total_epochs 10 --sensitive_name Age --batch_size 1024 --lr 0.01 --sens_classes 2 --val_strategy loss --output_dim 1 --num_classes 1
run hash (first 10 digits): fb24e9a7d2
Random seed: [55, 88, 65]
loaded dataset PAPILA
Training epoch 0: AUC:0.5051697530864196
Training epoch 0: cls loss:0.5227768421173096, adv loss:-0.17878487706184387
Traceback (most recent call last):
  File "main.py", line 34, in <module>
    pred_df = train(model, opt)
  File "main.py", line 13, in train
    ifbreak = model.train(epoch)
  File "/workspace/MEDFAIR/models/basenet.py", line 162, in train
    val_loss, val_auc, log_dict, pred_df = self._val(self.val_loader)
  File "/workspace/MEDFAIR/models/LAFTR/LAFTR.py", line 160, in _val
    auc, val_loss, log_dict, pred_df = standard_val(self.opt, self.network, loader, self._criterion, self.bianry_train_multi_test, self.wandb)
  File "/workspace/MEDFAIR/models/utils.py", line 48, in standard_val
    outputs, features = network.forward(images)
TypeError: forward() missing 1 required positional argument: 'Y'
```

SWA:

```
(fairmed) root@537614b35cbf:/workspace/MEDFAIR# python main.py --experiment SWA --experiment_name test --dataset_name PAPILA --backbone cusResNet18 --total_epochs 10 --sensitive_name Age --batch_size 1024 --lr 0.01 --sens_classes 2 --val_strategy loss --output_dim 1 --num_classes 1
run hash (first 10 digits): a748e06554
Random seed: [25, 49, 16]
loaded dataset PAPILA
Training epoch 0: AUC:40.733024691358025
Training epoch 0: loss:1.184382677078247
Validation epoch 1: validation loss:0.5855087041854858, AUC:49.853372434017594
saving best model in epoch 0 in your_path/fariness_data/model_records/PAPILA/Age/cusResNet18/SWA/25_best.pth
Finish training epoch 1, Val AUC: 49.853372434017594, time used: 0:00:04.349343
Training epoch 1: AUC:48.773148148148145
Training epoch 1: loss:0.497341126203537
Validation epoch 2: validation loss:0.5802542567253113, AUC:52.7859237536657
saving best model in epoch 1 in your_path/fariness_data/model_records/PAPILA/Age/cusResNet18/SWA/25_best.pth
Finish training epoch 2, Val AUC: 52.7859237536657, time used: 0:00:02.775651
Training epoch 2: AUC:51.018518518518505
Training epoch 2: loss:0.4866572618484497
Validation epoch 3: validation loss:0.5822473168373108, AUC:53.3724340175953
Finish training epoch 3, Val AUC: 53.3724340175953, time used: 0:00:02.745621
Training epoch 3: AUC:57.45370370370371
Training epoch 3: loss:0.4754384160041809
Validation epoch 4: validation loss:0.5816969275474548, AUC:56.01173020527859
Finish training epoch 4, Val AUC: 56.01173020527859, time used: 0:00:02.561891
Training epoch 4: AUC:54.85339506172839
Training epoch 4: loss:0.4786106050014496
Validation epoch 5: validation loss:0.5822972655296326, AUC:56.598240469208214
Finish training epoch 5, Val AUC: 56.598240469208214, time used: 0:00:02.677080
Training epoch 5: AUC:58.74228395061728
Training epoch 5: loss:0.4696345329284668
Validation epoch 6: validation loss:0.5794610381126404, AUC:57.771260997067444
saving best model in epoch 5 in your_path/fariness_data/model_records/PAPILA/Age/cusResNet18/SWA/25_best.pth
Finish training epoch 6, Val AUC: 57.771260997067444, time used: 0:00:02.590286
Training epoch 6: AUC:66.34259259259258
Training epoch 6: loss:0.45700129866600037
Validation epoch 7: validation loss:0.5797523260116577, AUC:58.65102639296188
Finish training epoch 7, Val AUC: 58.65102639296188, time used: 0:00:02.527177
Training epoch 7: AUC:69.84567901234568
Training epoch 7: loss:0.4469977021217346
Validation epoch 8: validation loss:0.5797624588012695, AUC:57.77126099706745
Finish training epoch 8, Val AUC: 57.77126099706745, time used: 0:00:02.448133
Training epoch 8: AUC:70.16975308641975
Training epoch 8: loss:0.4463978409767151
Validation epoch 9: validation loss:0.5803178548812866, AUC:58.65102639296187
Finish training epoch 9, Val AUC: 58.65102639296187, time used: 0:00:02.214667
Training epoch 9: AUC:72.85493827160494
Training epoch 9: loss:0.44254598021507263
Validation epoch 10: validation loss:0.580383837223053, AUC:56.59824046920821
Finish training epoch 10, Val AUC: 56.59824046920821, time used: 0:00:02.414403
Validation performance: {'Val Overall AUC': 0.5777126099706744, 'Val auc-group_0': 0.0625, 'Val auc-group_1': 0.6799999999999999, 'Val Overall Acc': 0.5952380952380952, 'Val acc-group_0': 0.5882352941176471, 'Val acc-group_1': 0.6, 'Val DP': 0.8729411764705883, 'Val EqOpp1': 0.975, 'Val EqOpp0': 0.4, 'Val EqOdd': 0.6875, 'Val EqOdd_0.5': 1.0, 'Val EqOdd_specificity_0.8': 0.7270833333333333, 'Val EqOdd_sensitivity_0.8': 0.5562499999999999, 'Val Overall ECE': 0.17898556625559217, 'Val Overall BCE': 0.5794610077802325, 'Val tpr_at_tnr_0': array(0.15714286), 'Val fnr_at_thres-group_0': 1.0, 'Val fpr_at_thres-group_0': 0.0, 'Val recall_at_thres-group_0': 0.0, 'Val specificity_at_thres-group_0': 1.0, 'Val ECE-group_0': 0.23644578018609214, 'Val BCE-group_0': 0.32750473417743503, 'Val tpr_at_tnr_1': array(0.5), 'Val fnr_at_thres-group_1': 1.0, 'Val fpr_at_thres-group_1': 0.0, 'Val recall_at_thres-group_1': 0.0, 'Val specificity_at_thres-group_1': 1.0, 'Val ECE-group_1': 0.24854536175727845, 'Val BCE-group_1': 0.7507912738301346, 'Val worst_auc': 0.0625, 'Val worst_group': 'auc-group_0', 'Val Overall FPR': 0.3870967741935484, 'Val Overall FNR': 0.45454545454545453, 'Val FPR-group_0': 0.375, 'Val FPR-group_1': 0.4, 'Val FNR-group_0': 1.0, 'Val FNR-group_1': 0.4}
Traceback (most recent call last):
  File "main.py", line 37, in <module>
    pred_df = model.test()
  File "/workspace/MEDFAIR/models/SWA/SWA.py", line 122, in test
    torch.optim.swa_utils.update_bn(self.loader, self.swa_model, device = self.device)
  File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SWA' object has no attribute 'loader'
```

CFair:

```
(fairmed) root@537614b35cbf:/workspace/MEDFAIR# python main.py --experiment CFair --experiment_name test --dataset_name PAPILA --backbone cusResNet18 --total_epochs 10 --sensitive_name Age --batch_size 1024 --lr 0.01 --sens_classes 2 --val_strategy loss --output_dim 1 --num_classes 1
run hash (first 10 digits): 6ae30fed6c
Random seed: [33, 2, 82]
loaded dataset PAPILA
Traceback (most recent call last):
  File "main.py", line 34, in <module>
    pred_df = train(model, opt)
  File "main.py", line 13, in train
    ifbreak = model.train(epoch)
  File "/workspace/MEDFAIR/models/basenet.py", line 160, in train
    self._train(self.train_loader)
  File "/workspace/MEDFAIR/models/CFair/CFair.py", line 63, in _train
    loss = self._criterion(ypreds, targets, pos_weight=reweight_target_tensor)
TypeError: _criterion() got an unexpected keyword argument 'pos_weight'
```

ys-zong commented 1 year ago

Thanks for pointing out the issues. LAFTR and CFair should be fixed now. When running these two methods, please also pass the argument --bianry_train_multi_test 2. This is a suboptimal implementation; I'll try to optimize it later. Also, you can try using Adam as the optimizer instead of SGD if you are training on small datasets. In our experiments, the two optimizers lead to similar results and method rankings in the end.
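
For context on the CFair error above: `_train` passes `pos_weight=reweight_target_tensor` to `_criterion`, which didn't accept that keyword. The fix is along these lines (a hypothetical sketch of the shape of the change, not the exact committed code):

```python
import torch.nn.functional as F

# Hypothetical sketch: accept the class-reweighting tensor that CFair's
# _train passes (pos_weight=reweight_target_tensor) and forward it to a
# weighted BCE-with-logits loss. Shapes and signature are assumptions.
def _criterion(self, outputs, targets, pos_weight=None):
    return F.binary_cross_entropy_with_logits(
        outputs.view(-1), targets.float().view(-1), pos_weight=pos_weight)
```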

For SWA, as I see from your log:

File "/workspace/MEDFAIR/models/SWA/SWA.py", line 122, in test
torch.optim.swa_utils.update_bn(self.loader, self.swa_model, device = self.device)
File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1269, in getattr
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SWA' object has no attribute 'loader'

The attribute error comes from passing self.loader to update_bn; in the current code, this line actually passes self.train_loader. Can you git pull to check that you have the latest code?

BTW, could you share the other bugs you faced so that I can also fix them for other users? (Methods not listed in the paper are not extensively tested.) Many thanks!

pearlmary commented 1 year ago

Hi Zong. Thank you for the quick reply. As you said, the issues with LAFTR and CFair are fixed now.

But the issue with SWA still exists, even though self.train_loader is passed.

Initially I also had to add the missing index in line 69 of SWA.py: `for i, (images, targets, sensitive_attr) in enumerate(loader):`

error log:

Traceback (most recent call last): File "main.py", line 37, in pred_df = model.test() File "/workspace/MEDFAIR/models/SWA/SWA.py", line 122, in test torch.optim.swa_utils.update_bn(self.train_loader, self.swa_model, device = self.device) File "/workspace/fairmed/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, kwargs) File "/workspace/fairmed/lib/python3.8/site-packages/torch/optim/swa_utils.py", line 187, in update_bn model(input) File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, *kwargs) File "/workspace/fairmed/lib/python3.8/site-packages/torch/optim/swa_utils.py", line 116, in forward return self.module(args, kwargs) File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, kwargs) File "/workspace/MEDFAIR/models/basemodels.py", line 22, in forward outputs = self.body(x) File "/workspace/fairmed/lib/python3.8/site-packages/torch/fx/graph_module.py", line 658, in call_wrapped return self._wrapped_call(self, *args, *kwargs) File "/workspace/fairmed/lib/python3.8/site-packages/torch/fx/graph_module.py", line 277, in call raise e File "/workspace/fairmed/lib/python3.8/site-packages/torch/fx/graph_module.py", line 267, in call return super(self.cls, obj).call(args, kwargs) # type: ignore[misc] File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, *kwargs) File ".5", line 5, in forward File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(input, kwargs) File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward return self._conv_forward(input, self.weight, self.bias) File "/workspace/fairmed/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [294]**

Sure, I will share the bugs I noted. Thanks.

ys-zong commented 1 year ago

Right, sorry about the confusion. When the dataloader returns a tuple, torch.optim.swa_utils.update_bn takes the first element of each batch as the model input by default -- that's why you encounter the shape error. A quick fix is to replace the torch built-in update_bn with the update_bn function I rewrote for SWAD and add the missing index there as well. A better implementation of this will be released soon.
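
For reference, a minimal sketch of such a rewritten update_bn, modeled on the torch built-in. It re-estimates the BatchNorm running statistics with real forward passes so they match the averaged weights; the (index, images, targets, sensitive_attr) tuple order is an assumption based on the shape error above, so adjust it to the actual dataloader.

```python
import torch

@torch.no_grad()
def update_bn(loader, model, device=None):
    # Collect the BatchNorm modules, reset their running statistics, and
    # remember their momenta (like torch.optim.swa_utils.update_bn does).
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum
    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta:
        module.momentum = None  # None => cumulative moving average

    # Unpack the batch tuple explicitly instead of taking its first element.
    # Assumed layout: (index, images, targets, sensitive_attr).
    for index, images, targets, sensitive_attr in loader:
        if device is not None:
            images = images.to(device)
        model(images)

    # Restore the original momenta and training mode.
    for module, momentum in momenta.items():
        module.momentum = momentum
    model.train(was_training)
```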

pearlmary commented 1 year ago

Hi Zong. Thank you for the reply. That clears up the confusion.

Would you mind telling me what went wrong while running a grid search on a Slurm cluster? I'm quite new to it.

```
(fairmed) root@537614b35cbf:/workspace/MEDFAIR# python sweep/train-sweep/sweep_batch.py baseline
wandb: Currently logged in as: mebinjose1. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.13.10
wandb: Run data is saved locally in /workspace/MEDFAIR/wandb/run-20230216_190049-5b5e6wv0
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run smooth-oath-20
wandb: ⭐️ View project at https://wandb.ai/mebinjose1/PAPILA%20baseline
wandb: 🚀 View run at https://wandb.ai/mebinjose1/PAPILA%20baseline/runs/5b5e6wv0
Create sweep with ID: c3owuw8b
Sweep URL: https://wandb.ai/mebinjose1/PAPILA%20baseline/sweeps/c3owuw8b
c3owuw8b
command is sbatch sweep/train-sweep/sweep_count.sh --sweep_id c3owuw8b
Traceback (most recent call last):
  File "sweep/train-sweep/sweep_batch.py", line 86, in <module>
    process = subprocess.Popen(CMD, stdout=subprocess.PIPE, universal_newlines=True)
  File "/opt/conda/lib/python3.8/subprocess.py", line 858, in __init__
    self._execute_child(args, executable, preexec_fn, close_fds,
  File "/opt/conda/lib/python3.8/subprocess.py", line 1704, in _execute_child
    raise child_exception_type(errno_num, err_msg, err_filename)
FileNotFoundError: [Errno 2] No such file or directory: 'sbatch'
wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
wandb: 🚀 View run smooth-oath-20 at: https://wandb.ai/mebinjose1/PAPILA%20baseline/runs/5b5e6wv0
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230216_190049-5b5e6wv0/logs
```

Thanks in advance.

ys-zong commented 1 year ago

Hi, I'm not sure why this happened, but I suspect it's related to your cluster environment rather than the code: the sweep script submits jobs via sbatch, which only exists where Slurm is installed. I've seen people ask similar questions when they don't have a Slurm environment at all.
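
A quick way to check whether Slurm is available where you are running the sweep:

```bash
# Prints the sbatch path on a Slurm login node; in a plain Docker
# container like yours, it will report that sbatch is missing.
which sbatch || echo "sbatch not found: no Slurm scheduler in this environment"
```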

You can also create new issues about different problems for better organization and indexing :)

ys-zong commented 1 year ago

FYI, LAFTR and CFair no longer need the --binary_train_multi_test argument, and --sens_classes can be passed the same way as for other methods. The SWA index problem is also fixed.

pearlmary commented 1 year ago

Oh great! Thank you.