TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

API Design Proposal for LiSSA algorithm ihvp #22

Closed. sx-liu closed this issue 4 months ago.

sx-liu commented 5 months ago

Background

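Building on the previously implemented hvp functions, this proposal completes the LiSSA algorithm for ihvp calculation. Compared with conjugate gradient (CG), LiSSA reduces the number of hvp calculations (each recursion step evaluates the Hessian on a single sampled input rather than on the whole dataset), which makes it more suitable for large datasets.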

Algorithm description

The LiSSA algorithm approximates the ihvp by averaging multiple independent samples. Each sample is estimated by a recursion based on the Taylor (Neumann series) expansion of the inverse Hessian.
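Written out, the recursion looks as follows. This uses the notation of Koh and Liang (https://arxiv.org/pdf/1703.04730.pdf, the paper referenced later in this thread) and assumes that $H = \mathbb{E}_z[\nabla_\theta^2 L(z, \theta)]$ is the Hessian averaged over the input list and that $z_{s_j}$ is drawn uniformly from that list. The starting point is the Neumann series, which converges when every eigenvalue of $H$ lies in $(0, 2)$:

$$H^{-1} v = \sum_{i=0}^{\infty} (I - H)^i v$$

$$h_0 = v, \qquad h_j = v + \big(I - \nabla_\theta^2 L(z_{s_j}, \theta)\big)\, h_{j-1}, \quad j = 1, \dots, t$$

$$H^{-1} v \approx \frac{1}{r} \sum_{k=1}^{r} h_t^{(k)}$$

Here $h_t^{(k)}$ is the result of the $k$-th independent run. With the exact $H$ in place of the sampled Hessian, $h_t = \sum_{i=0}^{t} (I - H)^i v$, i.e. the first $t + 1$ terms of the series. In the API below, $t$ corresponds to `recursion_depth` and $r$ to `num_repeat`.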

API Design

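The API design basically follows the existing implementations of hvp and ihvp_cg. Likewise, there are two versions, which calculate the ihvp for non-fixed x's (`ihvp_lissa`, where x is passed at call time) or fixed x's (`ihvp_at_x_lissa`, where x is bound when the function is created).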

The first argument of the LiSSA functions is the function whose Hessian is estimated, such as $L(\cdot, \cdot)$. The input list should be a list of the form $[(z_0, \theta), \dots, (z_n, \theta)]$.

```python
from typing import Callable, List, Tuple


def ihvp_lissa(func: Callable,
               argnums: int = 0,
               num_repeat: int = 10,
               recursion_depth: int = 5000,
               mode: str = "rev-rev") -> Callable:
    """LiSSA ihvp algorithm function.

    ihvp stands for inverse-hessian-vector product. Returns a function that,
    when given vectors, computes the product of the inverse Hessian and each vector.

    The LiSSA algorithm approximates the ihvp by averaging multiple samples.
    Each sample is estimated by a recursion based on the Taylor expansion.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. The Hessian will
            be estimated on this function.
        argnums (int): An integer, default to 0. Specifies which argument of func
            to compute the inverse Hessian with respect to.
        num_repeat (int): An integer, default to 10. Specifies the number of
            repeats of the procedure (ihvp samples) to average over.
        recursion_depth (int): An integer, default to 5000. Specifies the number of
            recursions used to estimate each ihvp sample.
        mode (str): The auto diff mode, which can have one of the following values:
            - rev-rev: calculate the Hessian with two reverse-mode auto-diff. It has
                       better compatibility but costs more memory.
            - rev-fwd: calculate the Hessian by composing reverse-mode and
                       forward-mode. It's more memory-efficient but may not be
                       supported by some operators.

    Returns:
        A function that takes a list of tuples of Tensor `x` and a vector `v` and returns
        the IHVP of the Hessian of `func` and `v`.
    """


def ihvp_at_x_lissa(func: Callable,
                    input_list: List[Tuple],
                    argnums: int = 0,
                    num_repeat: int = 10,
                    recursion_depth: int = 5000,
                    mode: str = "rev-rev") -> Callable:
    """LiSSA ihvp algorithm function (with fixed x).

    ihvp stands for inverse-hessian-vector product. Returns a function that,
    when given vectors, computes the product of the inverse Hessian and each vector.

    The LiSSA algorithm approximates the ihvp by averaging multiple samples.
    Each sample is estimated by a recursion based on the Taylor expansion.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. The Hessian will
            be estimated on this function.
        input_list (List[Tuple]): List of arguments for multiple calls of `func`.
            Each tuple inside the list should be a valid set of arguments to `func`.
        argnums (int): An integer, default to 0. Specifies which argument of func
            to compute the inverse Hessian with respect to.
        num_repeat (int): An integer, default to 10. Specifies the number of
            repeats of the procedure (ihvp samples) to average over.
        recursion_depth (int): An integer, default to 5000. Specifies the number of
            recursions used to estimate each ihvp sample.
        mode (str): The auto diff mode, which can have one of the following values:
            - rev-rev: calculate the Hessian with two reverse-mode auto-diff. It has
                       better compatibility but costs more memory.
            - rev-fwd: calculate the Hessian by composing reverse-mode and
                       forward-mode. It's more memory-efficient but may not be
                       supported by some operators.

    Returns:
        A function that takes a vector `v` and returns the IHVP of the Hessian
        of `func` and `v`.
    """
```
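To make the proposed semantics concrete, here is a minimal sketch of `ihvp_lissa` on top of `torch.func`. It is not the dattri implementation. Several pieces are assumptions: the `damping` and `scaling` keyword arguments and their defaults are hypothetical placeholders for the hyperparameters discussed later in the thread, one tuple is drawn uniformly from `input_list` at each recursion step, and the differentiated argument is taken to be a 1-D tensor.

```python
import random
from typing import Callable, List, Tuple

import torch
from torch.func import grad, jvp, vjp


def _hvp(func: Callable, args: Tuple, argnums: int,
         v: torch.Tensor, mode: str) -> torch.Tensor:
    """Hessian-vector product of `func` w.r.t. args[argnums], evaluated at `args`."""

    def scalar_f(p: torch.Tensor) -> torch.Tensor:
        new_args = list(args)
        new_args[argnums] = p
        # torch.func.grad needs a 0-dim output, so reduce the single-element tensor.
        return func(*new_args).sum()

    p = args[argnums]
    if mode == "rev-rev":
        # Reverse-over-reverse: VJP of the gradient (equals Hv because H is symmetric).
        _, vjp_fn = vjp(grad(scalar_f), p)
        return vjp_fn(v)[0]
    if mode == "rev-fwd":
        # Forward-over-reverse: JVP of the gradient.
        return jvp(grad(scalar_f), (p,), (v,))[1]
    raise ValueError(f"Unknown mode: {mode}")


def ihvp_lissa(func: Callable,
               argnums: int = 0,
               num_repeat: int = 10,
               recursion_depth: int = 5000,
               mode: str = "rev-rev",
               damping: float = 0.0,   # hypothetical hyperparameter (lambda)
               scaling: float = 50.0,  # hypothetical hyperparameter (s)
               ) -> Callable:
    def _ihvp(input_list: List[Tuple], v: torch.Tensor) -> torch.Tensor:
        # v has shape (num_vec, dim); one ihvp estimate is returned per row.
        estimates = []
        for vec in v:
            total = torch.zeros_like(vec)
            for _ in range(num_repeat):
                h = vec.clone()  # h_0 = v
                for _ in range(recursion_depth):
                    args = random.choice(input_list)  # sample z_{s_j} uniformly
                    hv = _hvp(func, args, argnums, h, mode)
                    # Recursion on (H + damping * I) / scaling instead of H.
                    h = vec + (1 - damping / scaling) * h - hv / scaling
                total = total + h / scaling  # undo the scaling
            estimates.append(total / num_repeat)  # average over the repeats
        return torch.stack(estimates)

    return _ihvp
```

With the demonstration's `mse_loss`, `ihvp_lissa(mse_loss, argnums=2)(input_list, torch.randn(5, 1))` would return a tensor of shape `(5, 1)`. This plain loop performs `num_vec * num_repeat * recursion_depth` hvp calls, so a real implementation would probably want to batch the vectors.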

Demonstration

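```python
import torch

# ihvp_lissa is the function proposed above.


def mse_loss(x, y, theta):
    return (x * theta - y) ** 2


# Inverse Hessian with respect to theta, the third argument of mse_loss.
ihvp = ihvp_lissa(mse_loss, argnums=2)

theta = torch.randn(1)
input_list = [
    (torch.randn(1), torch.randn(1), theta),
    (torch.randn(1), torch.randn(1), theta),
]
# 5 vectors, each with the same dimension as theta (1).
vec = torch.randn(5, 1)
ihvp(input_list, vec)
```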
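Because this example has a single parameter, the exact ihvp is cheap to compute, which gives a reference value to compare the LiSSA estimate against (only approximately, since LiSSA is stochastic). The sketch below assumes that the quantity being approximated is the Hessian averaged uniformly over `input_list`. It uses `torch.func.hessian` and checks it against the closed form $\frac{\partial^2}{\partial\theta^2}(x\theta - y)^2 = 2x^2$.

```python
import torch
from torch.func import hessian


def mse_loss(x, y, theta):
    return (x * theta - y) ** 2


torch.manual_seed(0)
theta = torch.randn(1)
input_list = [(torch.randn(1), torch.randn(1), theta) for _ in range(2)]
vec = torch.randn(5, 1)

# Per-sample Hessian w.r.t. theta (argnums=2). The output and theta both have
# shape (1,), so hessian() returns shape (1, 1, 1); reshape it to (1, 1).
per_sample = [hessian(mse_loss, argnums=2)(x, y, t).reshape(1, 1)
              for x, y, t in input_list]
H = torch.stack(per_sample).mean(dim=0)  # mean Hessian, shape (1, 1)

# Closed form of the same mean Hessian: average of 2 * x^2.
H_closed = torch.stack([2 * x ** 2 for x, _, _ in input_list]).mean()

# Exact ihvp for every vector (rows of vec), shape (5, 1).
exact_ihvp = torch.linalg.solve(H, vec.T).T
```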

TheaperDeng commented 5 months ago

Could you provide more detail about the difference between num_samples and recursion_depth? I guess they are r and t in https://arxiv.org/pdf/1703.04730.pdf, right? If so, maybe we can rename num_samples -> num_repeat and explain it as "Specifies the number of repeats of the procedure to average over."

minor issue:

  1. LISSA -> LiSSA

Others LGTM

sx-liu commented 5 months ago

I see. Thanks for the advice!

sx-liu commented 5 months ago

@TheaperDeng So how should we deal with hyperparameters, i.e. the damping and scaling factors?

TheaperDeng commented 5 months ago

> @TheaperDeng So how should we deal with hyperparameters, i.e. the damping and scaling factors?

I think you can mark these two as TODOs for now and fix them to the most common default values.
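One way to bring in both hyperparameters is sketched below. It assumes that a damping $\lambda \ge 0$ adds $\lambda I$ to the Hessian and that a scaling $s > 0$ divides it. The convention the library finally adopts may differ. Since $(H + \lambda I)^{-1} v = \frac{1}{s}\big((H + \lambda I)/s\big)^{-1} v$, the plain LiSSA recursion can be run on $(H + \lambda I)/s$ instead of $H$, with the result divided by $s$:

$$h_0 = v, \qquad h_j = v + \Big(1 - \frac{\lambda}{s}\Big) h_{j-1} - \frac{1}{s}\,\nabla_\theta^2 L(z_{s_j}, \theta)\, h_{j-1}, \quad j = 1, \dots, t$$

$$(H + \lambda I)^{-1} v \approx \frac{1}{s} \cdot \frac{1}{r} \sum_{k=1}^{r} h_t^{(k)}$$

For the exact Hessian, this converges when every eigenvalue of $(H + \lambda I)/s$ lies in $(0, 2)$. That means $H + \lambda I$ must be positive definite and $s$ must exceed half of its largest eigenvalue. Damping keeps the matrix invertible, and scaling keeps the series from diverging when $H$ has large eigenvalues.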

sx-liu commented 5 months ago

@TheaperDeng Another minor concern: if I understand correctly, in the earlier implementations the ihvp is computed as $v \cdot H^{-1}$, which means the product is taken from the left?
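For a twice continuously differentiable loss, the Hessian is symmetric, and so is its inverse. Therefore $v^\top H^{-1} = (H^{-1} v)^\top$, and left and right products give the same numbers up to a transpose. The check below is illustrative. It uses a random symmetric positive-definite matrix as a stand-in for a Hessian and stores the vectors as rows, matching the `(num_vec, dim)` shape used for `vec` in the demonstration.

```python
import torch

torch.manual_seed(0)
d, k = 4, 3
M = torch.randn(d, d, dtype=torch.float64)
# Symmetric positive-definite stand-in for a Hessian.
H = M @ M.T + d * torch.eye(d, dtype=torch.float64)
# k vectors stored as rows, like vec of shape (num_vec, dim).
V = torch.randn(k, d, dtype=torch.float64)

H_inv = torch.linalg.inv(H)
left = V @ H_inv         # each row is v^T H^{-1} (product from the left)
right = (H_inv @ V.T).T  # each row is (H^{-1} v)^T (product from the right)
```

For a non-symmetric matrix the two would differ, so the distinction only matters if the "Hessian" is computed inexactly in a way that breaks symmetry.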