suinleelab / vit-shapley


Training SHAP on ViT for custom dataset #9

Open mdabedr opened 6 months ago

mdabedr commented 6 months ago

Hello, could you please provide some guidelines on how to obtain SHAP values for a fine-tuned vision transformer on a custom dataset?

I am fine-tuning google/vit-base-patch16-224-in21k with a classifier head on my own dataset. How can I get Shapley values with it?

chanwkimlab commented 6 months ago

Hi, thanks for your interest in our work. Once you have your fine-tuned ViT classifier, the next step is to train a surrogate model, where your ViT model is fine-tuned with random masking so that it can accommodate held-out image patches. The final step is to train an explainer model using our custom loss function. The scripts for each step are available here.
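
In case it helps to see the shape of the surrogate step, below is a minimal sketch (not the repository's actual training code) of fine-tuning a copy of the classifier on randomly masked patches so that its predictions on masked inputs track the original model's predictions on the full image. The 16-pixel patch size, the per-image masking rate, and the KL-divergence objective are assumptions for illustration only.

    import torch
    import torch.nn.functional as F

    # Sketch only: fine-tune a surrogate copy of the classifier on randomly
    # masked inputs so it matches the teacher's full-image predictions.
    # The 14x14 patch grid (224px ViT-B/16) and KL objective are assumptions.

    def random_patch_mask(batch_size, num_patches=196, device="cpu"):
        """Sample a binary mask over patches (1 = keep, 0 = hide)."""
        keep_prob = torch.rand(batch_size, 1, device=device)  # per-image masking rate
        return (torch.rand(batch_size, num_patches, device=device) < keep_prob).float()

    def apply_mask_to_pixels(images, mask, patch_size=16):
        """Zero out masked patches of a (B, 3, 224, 224) image batch."""
        b = images.shape[0]
        grid = images.shape[-1] // patch_size  # 14 for 224px input
        mask_2d = mask.view(b, 1, grid, grid)
        mask_2d = F.interpolate(mask_2d, scale_factor=patch_size, mode="nearest")
        return images * mask_2d

    def surrogate_step(surrogate, teacher, images, optimizer):
        """One optimization step: match the teacher's output under random masking."""
        mask = random_patch_mask(images.shape[0], device=images.device)
        masked = apply_mask_to_pixels(images, mask)
        with torch.no_grad():
            target = F.softmax(teacher(images), dim=-1)  # full-image predictions
        log_pred = F.log_softmax(surrogate(masked), dim=-1)
        loss = F.kl_div(log_pred, target, reduction="batchmean")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

The explainer model is then trained on top of this surrogate with the custom loss; see the scripts linked above for the actual implementation.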

MirekJara commented 1 month ago

Hi @chanwkimlab,

I'm currently trying to use the scripts for training the surrogate model. Based on these lines of code in main.py (lines 63-70):

    if datasets == "MURA":
        datamodule = MURADataModule(**dataset_parameters)
    elif datasets == "ImageNette":
        datamodule = ImageNetteDataModule(**dataset_parameters)
    elif datasets == "Pet":
        datamodule = PetDataModule(**dataset_parameters)
    else:
        raise ValueError("Invalid 'datasets' configuration")

I assume that I need to implement Dataset and DataModule classes for my own dataset. Is that right, or is there a more straightforward way to do this? If you know of a repository that uses the code in this way, that would also be a huge help.

Anyway, thanks in advance!

chanwkimlab commented 1 month ago

You may need to slightly modify the dataset implementation to fit your data, as the current ViT Shapley implementation expects a specific output format from the __getitem__ function: {"images": img, "labels": label, "path": img_path}. https://github.com/suinleelab/vit-shapley/blob/master/vit_shapley/datamodules/datasets/base_dataset.py#L233
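
For illustration only, a custom dataset matching that format might look roughly like the sketch below; the class name, the (path, label) sample list, and the transform are hypothetical and would need to be adapted to your data and to the base_dataset.py interface linked above.

    import os
    from PIL import Image
    from torch.utils.data import Dataset
    from torchvision import transforms

    class MyCustomDataset(Dataset):
        """Hypothetical dataset returning the dict expected by ViT Shapley:
        {"images": img, "labels": label, "path": img_path}."""

        def __init__(self, root, samples, image_size=224):
            # `samples` is assumed to be a list of (relative_path, class_index) pairs.
            self.root = root
            self.samples = samples
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            rel_path, label = self.samples[idx]
            img_path = os.path.join(self.root, rel_path)
            img = self.transform(Image.open(img_path).convert("RGB"))
            return {"images": img, "labels": label, "path": img_path}

A matching DataModule (analogous to MURADataModule or PetDataModule) would then wrap train/val/test instances of this class so it can be plugged into the dispatch in main.py.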