suinleelab / vit-shapley


missing 9 required positional arguments: 'explanation_location_train', ... #3

Open RuoyuChen10 opened 2 years ago

RuoyuChen10 commented 2 years ago

Hi there, thanks for your great work. When I run 1_surrogate_sanity_check.ipynb and reach the 5th cell:

def generate_mask(num_players: int, num_mask_samples: int or None = None, paired_mask_samples: bool = True,
                  mode: str = 'uniform', random_state: np.random.RandomState or None = None) -> np.array:
...

It raises an error:

---> 80 datamodule = set_datamodule(datasets=_config["datasets"],
     81                             dataset_location=_config["dataset_location"],
     82                             transforms_train=_config["transforms_train"],
     83                             transforms_val=_config["transforms_val"],
     84                             transforms_test=_config["transforms_test"],
     85                             num_workers=_config["num_workers"],
     86                             per_gpu_batch_size=_config["per_gpu_batch_size"],
     87                             test_data_split=_config["test_data_split"])

TypeError: __init__() missing 9 required positional arguments: 'explanation_location_train', 'explanation_mask_amount_train', 'explanation_mask_ascending_train', 'explanation_location_val', 'explanation_mask_amount_val', 'explanation_mask_ascending_val', 'explanation_location_test', 'explanation_mask_amount_test', and 'explanation_mask_ascending_test'

I printed _config and found:

{'stage': 'classifier',
 'wandb_project_name': 'default_wandb_project_name',
 ...
 'explanation_location_train': None,
 'explanation_mask_amount_train': None,
 'explanation_mask_ascending_train': None,
 'explanation_location_val': None,
 'explanation_mask_amount_val': None,
 'explanation_mask_ascending_val': None,
 'explanation_location_test': None,
 'explanation_mask_amount_test': None,
 'explanation_mask_ascending_test': None,
 'output_dim': 10,
 'target_type': 'multiclass',
...
 'unfreeze_after_gradual': False,
 'val_check_interval': 1.0,
 'test_only': False,
 'resume_from': None,
 'fast_dev_run': False}

These values are all set to None. How can I fix this?
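The TypeError says the datamodule's `__init__` requires the nine `explanation_*` parameters, but the `set_datamodule(...)` call never passes them, even though `_config` already holds them (as `None`). A minimal sketch of one generic way to forward them, assuming a hypothetical `DemoDataModule` class standing in for the repo's datamodule (its real signature is not shown in this thread) and a hypothetical helper `build_from_config`: pick every `_config` key that names a constructor parameter via `inspect.signature`.

```python
import inspect
from typing import Any, Dict, Optional


class DemoDataModule:
    """Hypothetical stand-in for the repo's datamodule class."""

    def __init__(self, dataset_location: str, num_workers: int,
                 explanation_location_train: Optional[str],
                 explanation_mask_amount_train: Optional[int]):
        self.dataset_location = dataset_location
        self.num_workers = num_workers
        self.explanation_location_train = explanation_location_train
        self.explanation_mask_amount_train = explanation_mask_amount_train


def build_from_config(cls: type, config: Dict[str, Any]) -> Any:
    """Instantiate cls, passing only the config keys its __init__ accepts."""
    params = inspect.signature(cls.__init__).parameters
    # Skip 'self' and keep only keys that name a constructor parameter.
    kwargs = {name: config[name] for name in params
              if name != "self" and name in config}
    return cls(**kwargs)


_config = {
    "stage": "classifier",               # not a constructor argument, so it is ignored
    "dataset_location": "/path/to/data",
    "num_workers": 4,
    "explanation_location_train": None,  # None values are still forwarded
    "explanation_mask_amount_train": None,
}

dm = build_from_config(DemoDataModule, _config)
print(dm.num_workers, dm.explanation_location_train)  # prints: 4 None
```

A key that is missing from the config entirely still produces the same TypeError, so genuinely absent settings are not silently defaulted.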

chanwkimlab commented 2 years ago

Hi, thanks for reporting the bug. I've corrected it by modifying the code slightly.

RuoyuChen10 commented 2 years ago


Thanks! However, another error appears when I run this cell:

def generate_mask(num_players: int, num_mask_samples: int or None = None, paired_mask_samples: bool = True,
                  mode: str = 'uniform', random_state: np.random.RandomState or None = None) -> np.array:
    """
    Args:
        num_players: the number of players in the coalitional game
        num_mask_samples: the number of masks to generate
        paired_mask_samples: if True, the generated masks are pairs of x and 1-x.
        mode: the distribution that the number of masked features follows. ('uniform' or 'shapley')
        random_state: random generator

    Returns:
        torch.Tensor of shape
        (num_masks, num_players) if num_masks is int
        (num_players) if num_masks is None

    """
    random_state = random_state or np.random
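For reference, the docstring in that cell describes the sampler's contract. A hypothetical sketch of a function meeting it (not the repo's implementation), assuming coalition sizes range over 1..num_players-1, "uniform" means a uniform size distribution, "shapley" means the Shapley kernel weight (n - 1) / (s(n - s)), and paired masks are interleaved as x, 1 - x. It returns a boolean `np.ndarray` rather than a `torch.Tensor`.

```python
from typing import Optional

import numpy as np


def generate_mask_sketch(num_players: int, num_mask_samples: Optional[int] = None,
                         paired_mask_samples: bool = True, mode: str = "uniform",
                         random_state: Optional[np.random.RandomState] = None) -> np.ndarray:
    """Hypothetical mask sampler following the docstring above (not the repo's code)."""
    if num_players < 2:
        raise ValueError("need at least two players")
    if random_state is None:
        random_state = np.random.RandomState()

    # Assumption: coalition sizes range over 1..num_players-1 (no empty or full mask).
    sizes = np.arange(1, num_players)
    if mode == "uniform":
        probs = np.full(len(sizes), 1.0 / len(sizes))
    elif mode == "shapley":
        # Shapley kernel weight for coalition size s: (n - 1) / (s * (n - s)).
        probs = (num_players - 1) / (sizes * (num_players - sizes))
        probs = probs / probs.sum()
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    def sample(k: int) -> np.ndarray:
        masks = np.zeros((k, num_players), dtype=bool)
        for row in masks:  # each row is a view, so assignment fills masks
            size = random_state.choice(sizes, p=probs)
            row[random_state.choice(num_players, size=size, replace=False)] = True
        return masks

    if num_mask_samples is None:
        return sample(1)[0]  # shape (num_players,)
    if not paired_mask_samples:
        return sample(num_mask_samples)
    if num_mask_samples % 2 != 0:
        raise ValueError("paired sampling needs an even num_mask_samples")
    half = sample(num_mask_samples // 2)
    # Interleave each mask x with its complement 1 - x: x0, ~x0, x1, ~x1, ...
    return np.stack([half, ~half], axis=1).reshape(num_mask_samples, num_players)
```

Pairing each mask with its complement is a common variance-reduction trick for Shapley-value estimation; under the assumptions above, a complement of a valid mask is also valid.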

Error:

--> 152 datamodule.set_train_dataset()
    153 datamodule.set_val_dataset()
    154 datamodule.set_test_dataset()

File ~/data3/vit-shapley/vit_shapley/datamodules/base_datamodule.py:47, in BaseDataModule.set_train_dataset(self)
     46 def set_train_dataset(self):
---> 47     self.train_dataset = self.dataset_cls(
     48         dataset_location=self.dataset_location,
     49         transform_params=self.transforms_train,
     50         explanation_location=self.explanation_location_train,
     51         explanation_mask_amount=self.explanation_mask_amount_train,
     52         explanation_mask_ascending=self.explanation_mask_ascending_train,
     53         split="train",
     54     )
...
--> 299                                                  interpolation=InterpolationMode.BILINEAR))
    300     del transforms_params_copied['RandomResizedCrop']
    302 if "VerticalFlip" in transforms_params_copied:

NameError: name 'InterpolationMode' is not defined

I also get the same error when I run the command from training_classifier.md:

python main.py with 'stage = "classifier"' \
'wandb_project_name = "wandb_transformer_interpretability_project"' 'exp_name = "ImageNette_classifier_vit_small_patch16_224_1e-5_train"' \
env_chanwkim 'gpus_classifier=[0]' \
dataset_ImageNette \
'classifier_backbone_type = "vit_small_patch16_224"' 'classifier_download_weight = True' 'classifier_load_path = None' \
training_hyperparameters_transformer 'checkpoint_metric = "accuracy"' 'learning_rate = 1e-5'

It fails with:

ERROR - ViT_shapley - Failed after 0:12:24!
Traceback (most recent call last):
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/experiment.py", line 312, in run_commandline
    return self.run(
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/experiment.py", line 276, in run
    run()
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/run.py", line 238, in __call__
    self.result = self.main_function(*args)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "main.py", line 519, in main
    trainer.fit(model_to_train, datamodule=datamodule)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1138, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1438, in _call_setup_hook
    self.datamodule.setup(stage=fn)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py", line 474, in wrapped_fn
    fn(*args, **kwargs)
  File "/home/cry/data3/vit-shapley/vit_shapley/datamodules/base_datamodule.py", line 78, in setup
    self.set_train_dataset()
  File "/home/cry/data3/vit-shapley/vit_shapley/datamodules/base_datamodule.py", line 47, in set_train_dataset
    self.train_dataset = self.dataset_cls(
  File "/home/cry/data3/vit-shapley/vit_shapley/datamodules/datasets/ImageNette_dataset.py", line 17, in __init__
    super().__init__(transform_params=transform_params, explanation_location=explanation_location,
  File "/home/cry/data3/vit-shapley/vit_shapley/datamodules/datasets/base_dataset.py", line 139, in __init__
    self.transforms, self.transform_resize, self.augmentation_function = self.parse_transforms_params(
  File "/home/cry/data3/vit-shapley/vit_shapley/datamodules/datasets/base_dataset.py", line 299, in parse_transforms_params
    interpolation=InterpolationMode.BILINEAR))
NameError: name 'InterpolationMode' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 138, in <module>
    def main(_config):
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/experiment.py", line 190, in automain
    self.run_commandline()
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/experiment.py", line 347, in run_commandline
    print_filtered_stacktrace()
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/utils.py", line 493, in print_filtered_stacktrace
    print(format_filtered_stacktrace(filter_traceback), file=sys.stderr)
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/utils.py", line 528, in format_filtered_stacktrace
    return "".join(filtered_traceback_format(tb_exception))
  File "/home/cry/anaconda3/envs/shap/lib/python3.8/site-packages/sacred/utils.py", line 568, in filtered_traceback_format
    current_tb = tb_exception.exc_traceback
AttributeError: 'TracebackException' object has no attribute 'exc_traceback'
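The NameError means `InterpolationMode` is not bound in base_dataset.py where `RandomResizedCrop` is built (the trailing sacred AttributeError is raised while printing that traceback). A minimal compatibility sketch, assuming the installed torchvision is too old to provide `torchvision.transforms.InterpolationMode` and, like older releases, takes PIL's integer resampling constant for `interpolation`. The helper name `resolve_bilinear` is hypothetical.

```python
def resolve_bilinear():
    """Return a 'bilinear' interpolation value accepted by the installed torchvision."""
    try:
        # Newer torchvision releases expose an InterpolationMode enum.
        from torchvision.transforms import InterpolationMode
        return InterpolationMode.BILINEAR
    except ImportError:
        # Older releases expect PIL's integer resampling constant instead.
        from PIL import Image
        return Image.BILINEAR
```

It would be used as `transforms.RandomResizedCrop(224, interpolation=resolve_bilinear())` in place of `InterpolationMode.BILINEAR`. Upgrading torchvision, as suggested below, avoids the shim entirely.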
RuoyuChen10 commented 2 years ago

Hi, could you help debug this error? Thanks.

chanwkimlab commented 2 years ago

I've found that this error occurs when the torchvision package is outdated. Which version of torchvision are you using? I am using 0.11.3.

RuoyuChen10 commented 2 years ago


Hi, I'm using pytorch==1.7.1 and torchvision==0.8.2. Which version do you recommend?

chanwkimlab commented 2 years ago

I'm not sure which torchvision release introduced the InterpolationMode interface. I'd recommend using a torchvision version released later than mine.
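To check an environment before running, a minimal sketch that parses a version string and compares it to a minimum. The 0.9.0 threshold is an assumption: to my knowledge it is the first torchvision release shipping `InterpolationMode`, but verify it against the torchvision release notes. `MIN_TORCHVISION`, `parse_version`, and `supports_interpolation_mode` are hypothetical names.

```python
import re
from typing import Tuple

# Assumption: InterpolationMode first shipped in torchvision 0.9.0.
MIN_TORCHVISION = (0, 9, 0)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '0.8.2+cu101' into (0, 8, 2); missing components are padded with 0."""
    release = version.split("+")[0]  # drop local build tags such as '+cu101'
    parts = []
    for piece in release.split("."):
        match = re.match(r"\d+", piece)  # keep leading digits, e.g. '0a0' -> 0
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def supports_interpolation_mode(version: str) -> bool:
    return parse_version(version) >= MIN_TORCHVISION
```

Usage: `import torchvision; print(torchvision.__version__, supports_interpolation_mode(torchvision.__version__))`. Upgrading torchvision generally means upgrading torch to the matching release too (for example, torchvision 0.11.3 pairs with torch 1.10.2).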