TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.algorithm] Add `AttributionTask` abstraction for new API design #99

Closed TheaperDeng closed 1 month ago

TheaperDeng commented 1 month ago

Description

1. Motivation and Context

  1. Introduce a new `AttributionTask` class to manage model and checkpoint information, reducing users' workload in extracting the internal model information needed by some algorithms.
  2. Provide two levels of Attributor API, making it easier for users to build customized attributors from our templates rather than from scratch.
  3. Support DataInf as one of the attributors.
  4. All examples, unit tests, and API documents are updated accordingly.

2. Summary of the change

  1. Introduce a new `AttributionTask` API, so that users follow the template below for attributor usage:

```python
# User code
model = create_model()

def f(params, data_target_pair):
    image, label = data_target_pair
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, image)
    return loss(yhat, label.long())

task = AttributionTask(target_func=f, model=model, checkpoints=model.state_dict())
attributor = Attributor(task=task, device="cuda")
```
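As a sanity check of this target-function signature, here is a minimal, self-contained sketch (toy one-parameter model, not dattri code) showing how `torch.func.grad` differentiates such an `f(params, data_target_pair)` with respect to `params`:

```python
import torch

# Toy sketch (not dattri code): a target function with the same
# f(params, data_target_pair) -> scalar-loss signature as above.
def f(params, data_target_pair):
    x, y = data_target_pair
    return ((params["w"] * x - y) ** 2).sum()

params = {"w": torch.tensor(2.0)}
data = (torch.tensor(3.0), torch.tensor(5.0))

# torch.func.grad differentiates w.r.t. the first argument (params).
g = torch.func.grad(f)(params, data)
# d/dw (w*x - y)^2 = 2*(w*x - y)*x = 2*(2*3 - 5)*3 = 6
```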


This `AttributionTask` provides member functions that are helpful for attributor development.
```python
# Developer code
task.get_grad_target_func(...)  # get the targeted function's gradient func
task.get_target_func(...)  # get the targeted function, which can be directly used by dattri.func
task.get_param(...)  # get the parameters w/ or w/o layer split, as well as the layer mapping automatically (for datainf and ek-fac)
task.register_forward_hook(...)  # TODO: return the handle of the hook for hidden states.
```
  2. Refactor the original `IFAttributor` into a more flexible `BaseInnerProductAttributor` and several high-level `Attributor` implementations.

Users who do not want to customize their own attributor may directly use the high-level implementations of Attributors.

```python
# User code
from dattri.algorithm.influence_function import (
    IFAttributorExplicit,
    IFAttributorCG,
    IFAttributorLiSSA,
    IFAttributorDataInf,
    IFAttributorArnoldi,
)

attributor = IFAttributorExplicit(task=task, device="cuda", regularization=0.01)
attributor = IFAttributorCG(task=task, device="cuda", regularization=0.01)
attributor = IFAttributorLiSSA(task=task, device="cuda", recursion_depth=100)
attributor = IFAttributorDataInf(task=task, device="cuda", regularization=0.01)
attributor = IFAttributorArnoldi(task=task, device="cuda", regularization=0.01)

# the following usage is unchanged (i.e., attributor.cache() and attributor.attribute())
```

Developers who want to create their own attributors may follow the templates we provide. Currently we have `BaseInnerProductAttributor` for a large family of gradient-based TDA methods that have the form `g^T H^{-1} g`.
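As a toy numeric illustration of this inner-product form (a minimal sketch with made-up gradients and Hessian, not dattri code):

```python
import torch

# Toy illustration of the g^T H^{-1} g score shared by this family of methods.
H = torch.tensor([[2.0, 0.0], [0.0, 4.0]])   # made-up Hessian of a convex loss
g_test = torch.tensor([1.0, 2.0])            # made-up gradient on a test sample
g_train = torch.tensor([2.0, 2.0])           # made-up gradient on a train sample

# Solve H x = g_train instead of forming H^{-1} explicitly, then take
# the inner product with the test gradient.
score = g_test @ torch.linalg.solve(H, g_train)
# solve gives [1.0, 0.5], so score = 1*1 + 2*0.5 = 2.0
```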

It also supports customization by overriding its member methods, for example:

```python
from functools import partial
from typing import Tuple

import torch


class MyAttributor(BaseInnerProductAttributor):
    def generate_test_query(
        self,
        index: int,
        data: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        # By default, this method calculates the gradient over test samples,
        # but developers may have their own design, e.g., apply some projectors.
        model_params, _ = self.task.get_param(index)
        return self.task.get_grad_target_func()(model_params, data)

    def generate_train_query(
        self,
        index: int,
        data: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        # By default, this method calculates the gradient over train samples,
        # but developers may have their own design, e.g., apply some projectors.
        model_params, _ = self.task.get_param(index)
        return self.task.get_grad_target_func()(model_params, data)

    def transformation_on_query(
        self,
        index: int,
        data: Tuple[torch.Tensor, ...],
        query: torch.Tensor,
        **transformation_kwargs,
    ) -> torch.Tensor:
        # This method is called on the test query (as the vector).
        # IF will calculate the Hessian/Fisher and perform the IHVP or IFVP;
        # Grad-Cos/Grad-Dot will directly return the test query.

        # Take explicit IHVP as an example:
        from dattri.func.ihvp import ihvp_explicit

        self.ihvp_func = ihvp_explicit(
            partial(self.task.get_target_func(), data_target_pair=data),
            **transformation_kwargs,
        )
        model_params, _ = self.task.get_param(index)
        return self.ihvp_func((model_params,), query).detach()
```
  3. DataInf is directly supported under this design, with no additional work on the users' side.
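As a standalone sketch of the explicit IHVP that `transformation_on_query` performs in the template above (a toy loss differentiated with `torch.func.hessian`; this is an assumed setup for illustration, not dattri's `ihvp_explicit`):

```python
import torch
from torch.func import hessian

# Toy convex loss; its Hessian is constant and easy to check by hand.
def loss(w):
    return (w ** 2).sum() + w[0] * w[1]

w = torch.tensor([1.0, 2.0])
H = hessian(loss)(w)              # [[2., 1.], [1., 2.]]
v = torch.tensor([3.0, 3.0])      # query vector (e.g., a test gradient)

# Explicit IHVP: solve H x = v rather than materializing H^{-1}.
ihvp = torch.linalg.solve(H, v)   # equals [1., 1.]
```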

3. What tests have been added/updated for the change?

4. TODOs

Some todos will be added to the project tracking page.