aski02 closed this pull request 4 months ago
Forgot to also update the project dependencies to use the current captum version from GitHub. I will fix this.
Thanks! We would need to set the captum version to the GitHub one in the pyproject.toml, as described here: https://chrisholdgraf.com/blog/2022/install-github-from-pyproject/
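For example, a minimal sketch of the relevant entry (assuming a standard PEP 621 [project] table; Poetry uses a different syntax for git dependencies):

[project]
dependencies = [
    "captum @ git+https://github.com/pytorch/captum.git",
]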
Thanks @aski02! I took a quick look and everything looks very very good already! Will take a closer look tomorrow or Friday for more detailed comments 🙂
Hello @aski02
Thanks for the great work! I left a couple of comments that you can respond to. But I also think @dilyabareeva could clarify things 😄
I looked into the Captum implementation of the ArnoldiInfluenceFunction again and noticed that even though the learning rate is mentioned in the docstrings as an optional return value of checkpoints_load_func, it is never actually used. Here is the corresponding line from captum.influence._core.influence_function:
checkpoints_load_func(model, checkpoint)
Therefore, I do not verify whether the function returns a learning rate, and only check that the given parameters are correct.
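For reference, a minimal sketch of what a compatible checkpoints_load_func could look like (assuming the checkpoint file holds a dict with a model state dict; the key name is illustrative):

import torch

def checkpoints_load_func(model: torch.nn.Module, checkpoint: str) -> None:
    # Load the stored parameters into the model in place. Any learning rate
    # stored alongside them would simply be ignored by the caller above.
    state = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])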
If we just make the batch_size parameter optional, we would still have to check whether or not to pass it to the self_influence function:
from typing import Any, Optional, Union

import torch

def self_influence_fn_from_explainer(
    explainer_cls: type,
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    train_dataset: torch.utils.data.Dataset,
    device: Union[str, torch.device],
    batch_size: Optional[int] = 32,
    **kwargs: Any,
) -> torch.Tensor:
    # _init_explainer is a helper defined elsewhere in this module.
    explainer = _init_explainer(
        explainer_cls=explainer_cls,
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        train_dataset=train_dataset,
        device=device,
        **kwargs,
    )
    # batch_size is always forwarded here, even for explainers whose
    # self_influence does not accept it.
    return explainer.self_influence(batch_size=batch_size)
A more elegant approach would likely be to pass it via kwargs, especially if other methods require more parameters in the future. Since this would also require some small changes to SimilarityInfluence, I have not done it yet.
However, I will check the parameters required by TracInCP when implementing its wrapper, and then see how to handle this elegantly.
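For illustration only, a rough sketch of the kwargs-based shape I have in mind (the self_influence_kwargs parameter is hypothetical and not in the current code):

from typing import Any, Optional, Union

import torch

def self_influence_fn_from_explainer(
    explainer_cls: type,
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    train_dataset: torch.utils.data.Dataset,
    device: Union[str, torch.device],
    self_influence_kwargs: Optional[dict] = None,  # hypothetical
    **kwargs: Any,
) -> torch.Tensor:
    explainer = _init_explainer(
        explainer_cls=explainer_cls,
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        train_dataset=train_dataset,
        device=device,
        **kwargs,
    )
    # Forward only what the caller explicitly provided, e.g.
    # {"batch_size": 32} for SimilarityInfluence; explainers whose
    # self_influence takes no batch size receive nothing extra.
    return explainer.self_influence(**(self_influence_kwargs or {}))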
Hi @aski02, thanks for this! I think we can merge.