TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.func] rename ihvp.py -> hessian.py/fisher.py #103

Closed. TheaperDeng closed this 1 month ago.

TheaperDeng commented 1 month ago

Description

1. Motivation and Context

This is a renaming PR: it splits the original `ihvp` collection into separate `hessian` and `fisher` modules so that the mathematical distinction between the two is clearer.

2. Summary of the change

  1. Rename and split `dattri.func.ihvp` into `dattri.func.hessian` and `dattri.func.fisher`, and update the unit tests, related import statements, and documentation accordingly.
    ```python
    # dattri.func.fisher
    """
    This module contains:
    - `ifvp_explicit`: IFVP via explicit FIM calculation.
    - `ifvp_at_x_explicit`: IFVP via explicit FIM calculation (with fixed x).
    - `ifvp_datainf`: DataInf IFVP algorithm function.
    - `ifvp_at_x_datainf`: DataInf IFVP algorithm function (with fixed x).
    - `ifvp_at_x_ekfac`: EK-FAC IFVP algorithm function (with fixed x).
    """

    # dattri.func.hessian
    """
    This module contains:
    - `hvp`: Calculate the Hessian-vector product (HVP) of a function.
    - `hvp_at_x`: Calculate the HVP of a function (with fixed x).
    - `ihvp_at_x_explicit`: IHVP via explicit Hessian calculation (with fixed x).
    - `ihvp_cg`: Conjugate gradient IHVP algorithm function.
    - `ihvp_at_x_cg`: Conjugate gradient IHVP algorithm function (with fixed x).
    - `ihvp_arnoldi`: Arnoldi iteration IHVP algorithm function.
    - `ihvp_at_x_arnoldi`: Arnoldi iteration IHVP algorithm function (with fixed x).
    - `ihvp_lissa`: LiSSA IHVP algorithm function.
    - `ihvp_at_x_lissa`: LiSSA IHVP algorithm function (with fixed x).
    """
    ```
  2. Add `ifvp_explicit` and `ifvp_at_x_explicit`.
    • Add a dedicated description to these functions that distinguishes the empirical FIM from the true FIM.
    • Check that they work on a trained NN, reaching >0.99 correlation with the explicit Hessian IHVP.
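For readers new to the `hessian` module, the following standalone sketch illustrates what an HVP, an explicit IHVP, and a conjugate-gradient IHVP compute. It is written in plain PyTorch (`torch.func`) and is not dattri's implementation: the `*_sketch` function names, their signatures, the `regularization` damping term, and the toy test function are all illustrative assumptions.

```python
import torch
from torch.func import grad, hessian, jvp


def hvp_sketch(f, x, v):
    """Hessian-vector product H(x) @ v without materializing H.

    Forward-over-reverse autodiff: the JVP of grad(f) at x in direction v
    equals H(x) @ v.
    """
    _, hv = jvp(grad(f), (x,), (v,))
    return hv


def ihvp_explicit_sketch(f, x, v, regularization=0.0):
    """Inverse HVP: build the full Hessian and solve (H + lambda * I) y = v.

    Memory is O(n^2) in the number of parameters, so this only suits small models.
    """
    h = hessian(f)(x)
    eye = torch.eye(x.numel(), dtype=x.dtype)
    return torch.linalg.solve(h + regularization * eye, v)


def ihvp_cg_sketch(f, x, v, max_iter=100, tol=1e-20):
    """Inverse HVP via conjugate gradient, using only HVPs.

    Assumes H(x) is symmetric positive definite; stops once the squared
    residual norm drops below tol.
    """
    y = torch.zeros_like(v)
    r = v.clone()  # residual v - H @ y, with y = 0 initially
    p = r.clone()  # current search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if rs_old < tol:
            break
        hp = hvp_sketch(f, x, p)
        alpha = rs_old / (p @ hp)
        y = y + alpha * p
        r = r - alpha * hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return y


if __name__ == "__main__":
    torch.manual_seed(0)
    n = 5
    m = torch.randn(n, n, dtype=torch.float64)
    a = m @ m.T + n * torch.eye(n, dtype=torch.float64)  # SPD matrix

    def f(x):
        # Strictly convex test function with Hessian a + diag(3 * x**2).
        return 0.5 * x @ a @ x + 0.25 * (x**4).sum()

    x = torch.randn(n, dtype=torch.float64)
    v = torch.randn(n, dtype=torch.float64)
    print(ihvp_explicit_sketch(f, x, v))
    print(ihvp_cg_sketch(f, x, v))
```

The CG variant only ever calls `hvp_sketch`, so its memory stays O(n) in the number of parameters, whereas the explicit variant materializes the full n × n Hessian.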
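To illustrate the empirical-versus-true FIM distinction, here is a sketch for a logistic-regression model, an assumed toy setting that is not part of this PR. The empirical FIM averages gradient outer products at the observed labels, while the true FIM takes the expectation under the model's own predictive distribution. For logistic regression the true FIM equals the Hessian of the mean loss, so the explicit IFVP reproduces the explicit IHVP. All names here (`ifvp_explicit_sketch`, `true_fim`, `empirical_fim`, `pearson_correlation`, ...) are hypothetical and do not reflect dattri's API.

```python
import torch
from torch.func import jacrev


def bce_loss(w, X, y):
    """Mean binary cross-entropy of a logistic-regression model with weights w."""
    logits = X @ w
    # log(1 + e^z) - y * z, written naively; fine for the small logits used here.
    return torch.mean(torch.log1p(torch.exp(logits)) - y * logits)


def per_sample_grads(w, X, y):
    """Gradient of each sample's loss w.r.t. w: (sigmoid(x_i @ w) - y_i) * x_i."""
    p = torch.sigmoid(X @ w)
    return (p - y).unsqueeze(1) * X  # shape (n_samples, n_features)


def empirical_fim(w, X, y):
    """Empirical FIM: mean outer product of gradients at the observed labels."""
    g = per_sample_grads(w, X, y)
    return g.T @ g / X.shape[0]


def true_fim(w, X):
    """True FIM: labels drawn from the model itself, y_i ~ Bernoulli(p_i).

    E[(p_i - y_i)^2] = p_i * (1 - p_i), so F = mean_i p_i (1 - p_i) x_i x_i^T.
    """
    p = torch.sigmoid(X @ w)
    return (X * (p * (1 - p)).unsqueeze(1)).T @ X / X.shape[0]


def ifvp_explicit_sketch(fim, v, regularization=0.0):
    """Inverse FIM-vector product: solve (F + lambda * I) z = v with a dense solver."""
    eye = torch.eye(fim.shape[0], dtype=fim.dtype)
    return torch.linalg.solve(fim + regularization * eye, v)


def pearson_correlation(a, b):
    """Pearson correlation between two vectors."""
    a, b = a.flatten() - a.mean(), b.flatten() - b.mean()
    return (a @ b) / (a.norm() * b.norm())


if __name__ == "__main__":
    torch.manual_seed(0)
    X = torch.randn(200, 4, dtype=torch.float64)
    y = (torch.rand(200, dtype=torch.float64) < 0.5).double()
    w = torch.randn(4, dtype=torch.float64)
    v = torch.randn(4, dtype=torch.float64)

    hess = jacrev(jacrev(bce_loss))(w, X, y)  # explicit Hessian of the mean loss
    ihvp = torch.linalg.solve(hess, v)
    ifvp_true = ifvp_explicit_sketch(true_fim(w, X), v)
    ifvp_emp = ifvp_explicit_sketch(empirical_fim(w, X, y), v)
    print("true FIM vs IHVP:     ", pearson_correlation(ifvp_true, ihvp).item())
    print("empirical FIM vs IHVP:", pearson_correlation(ifvp_emp, ihvp).item())
```

For deeper, non-linear networks the FIM is only an approximation of the Hessian, so a correlation check like `pearson_correlation` is one practical way to compare an IFVP against an explicit IHVP.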

3. What tests have been added/updated for the change?