TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.model_utils] Add a wrapper for register_forward_hook #50

Closed: tingwl0122 closed this pull request 4 months ago

tingwl0122 commented 4 months ago

Description

This PR implements a wrapper function that attaches a forward hook (via PyTorch's `register_forward_hook`) to a user's model at a specified layer.

1. Motivation and Context

For the Representer Point Selection algorithm, we need the features of the last intermediate layer to compute representer scores for explainability. The hook lets users capture the features of any chosen layer (not only the last intermediate one) with a single forward pass of their data through the model.
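A minimal sketch of how such a wrapper can work, assuming the target layer is addressed by its name in `model.named_modules()`. The helper name `capture_layer_output` and its signature are hypothetical and are not dattri's API. The hook stores the layer's output during one forward pass and is removed afterward so it does not fire on later calls.

```python
import torch
import torch.nn as nn


def capture_layer_output(model, layer_name, data):
    """Hypothetical helper: run one forward pass and return the output
    of the submodule named `layer_name`."""
    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"no submodule named {layer_name!r}")
    captured = {}

    def hook(module, inputs, output):
        # Detach so the stored feature does not keep the autograd graph alive.
        captured["feature"] = output.detach()

    handle = modules[layer_name].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(data)
    finally:
        # Always remove the hook so it does not fire on later forward passes.
        handle.remove()
    return captured["feature"]


# Usage on a toy model: in nn.Sequential, submodules are named "0", "1", "2".
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.randn(5, 4)
feat = capture_layer_output(model, "1", x)  # output of the ReLU
print(feat.shape)  # torch.Size([5, 8])
```

Removing the handle in a `finally` block keeps the model clean even if the forward pass raises.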

2. Summary of the change

3. What tests have been added/updated for the change?

tingwl0122 commented 4 months ago

Hi @TheaperDeng, please have a look if you have time.

tingwl0122 commented 4 months ago

The hook function will be deprecated in #58. The current `get_layer_feature` extracts the output of the second-to-last layer, which misses any activation function applied after that layer. Hence, we instead obtain the input of the last layer directly.
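To illustrate why the two quantities can differ, here is a minimal sketch, assuming a toy model whose ReLU is applied functionally inside `forward()` and is therefore not a separate submodule. The names `capture_layer_input` and `Net` are hypothetical and do not reflect how #58 implements the change. A forward hook's second argument is the tuple of inputs passed to the layer, so hooking the final layer yields the post-activation features, while the output of the second-to-last layer is pre-activation.

```python
import torch
import torch.nn as nn


def capture_layer_input(model, layer, data):
    """Hypothetical helper: return the input tensor that `layer`
    receives during one forward pass."""
    captured = {}

    def hook(module, inputs, output):
        # `inputs` is the tuple of positional args passed to layer.forward().
        captured["feature"] = inputs[0].detach()

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(data)
    finally:
        handle.remove()
    return captured["feature"]


class Net(nn.Module):
    """Toy model: the activation sits between fc1 and fc2 as a function call."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
net = Net()
x = torch.randn(5, 4)
last_input = capture_layer_input(net, net.fc2, x)  # post-ReLU features
fc1_out = net.fc1(x).detach()                      # second-to-last output, pre-ReLU
print(torch.allclose(last_input, torch.relu(fc1_out)))  # True
```

Hooking the output of `fc1` here would return `fc1_out`, which still contains negative values that the final layer never sees.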