TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

API Design Proposal for DataInf Algorithm ihvp. #28

Closed: charles-pyj closed this issue 3 months ago.

charles-pyj commented 4 months ago

Background

Yongchan Kwon et al. proposed DataInf, a method for influence estimation when the loss function is the negative log-likelihood. Compared to other methods such as LiSSA or RandProj, DataInf has a closed-form solution and is much more efficient in both time and memory. The method is reported to perform well on models fine-tuned with parameter-efficient techniques such as LoRA, which makes it useful for large language models, diffusion models, and other generative AI models.

Algorithm Description

To cope with the large number of parameters, DataInf computes influence layer by layer. Assuming an NLL loss, it approximates the ihvp in closed form in $O(\sum_{l=1}^{L} n d_l)$ time, where $n$ is the number of samples used to compute the ihvp, $L$ is the number of layers in the model, and $d_l$ is the dimension of the $l$-th layer. Note that the method needs only the gradients, not access to the Hessian.
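As we read the DataInf paper, the closed form works as follows for a single layer $l$, with per-sample gradients $g_{i,l} \in \mathbb{R}^{d_l}$, a vector $v_l \in \mathbb{R}^{d_l}$, and regularization $\lambda_l > 0$ (this notation is ours). Under the NLL assumption the layer's Hessian is approximated by the average outer product of gradients; DataInf then swaps the order of averaging and inversion, so that each rank-one term can be inverted with the Sherman-Morrison formula:

```latex
\Big(\frac{1}{n}\sum_{i=1}^{n} g_{i,l}\,g_{i,l}^\top + \lambda_l I\Big)^{-1} v_l
\;\approx\; \frac{1}{n}\sum_{i=1}^{n}\big(g_{i,l}\,g_{i,l}^\top + \lambda_l I\big)^{-1} v_l
\;=\; \frac{1}{n\lambda_l}\sum_{i=1}^{n}\Big(v_l - \frac{g_{i,l}^\top v_l}{\lambda_l + g_{i,l}^\top g_{i,l}}\, g_{i,l}\Big)
```

Each summand needs only the two inner products $g_{i,l}^\top v_l$ and $g_{i,l}^\top g_{i,l}$, so layer $l$ costs $O(n d_l)$ and the whole model $O(\sum_{l=1}^{L} n d_l)$. The division by $\lambda_l$ is also why the per-layer regularization has to be positive for this expression to be defined.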

API Design

The API differs from the other ihvp functions because this method works from layer-wise gradients and does not need to approximate the Hessian.
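To illustrate that the method consumes only gradients, here is a minimal sketch of the per-layer DataInf product in PyTorch, assuming each layer's per-sample gradients are already flattened into an `(n, d_l)` tensor. `datainf_layer_ihvp` is a hypothetical helper name, not part of dattri.

```python
import torch


def datainf_layer_ihvp(grads: torch.Tensor, v: torch.Tensor, lam: float) -> torch.Tensor:
    """Approximate (mean_i g_i g_i^T + lam * I)^{-1} v for one layer, DataInf style.

    grads: (n, d) per-sample gradients of this layer, flattened row-wise.
    v:     (d,) vector for the same layer, flattened.
    lam:   positive regularization for this layer.
    """
    if lam <= 0:
        raise ValueError("lam must be positive: the closed form divides by it")
    n = grads.shape[0]
    g_dot_v = grads @ v                     # (n,) inner products g_i^T v
    g_norm_sq = (grads * grads).sum(dim=1)  # (n,) squared norms g_i^T g_i
    coef = g_dot_v / (lam + g_norm_sq)      # (n,) Sherman-Morrison coefficients
    # sum_i (v - coef_i * g_i) == n * v - grads^T @ coef
    return (n * v - grads.T @ coef) / (n * lam)
```

Because it never forms a $d_l \times d_l$ matrix, memory stays at $O(n d_l)$ per layer.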

from typing import Callable, List, Optional

def ihvp_datainf(func: Callable,
                 argnums: int = 0,
                 regularization: Optional[List[float]] = None,
                 ) -> Callable:
    '''DataInf ihvp algorithm function.

    Standing for the inverse-Hessian-vector product, returns a function that,
    when given inputs and a vector, computes the product of the inverse Hessian
    and the vector.

    DataInf assumes the loss is cross-entropy (negative log-likelihood) and
    therefore derives a closed-form ihvp without having to approximate the
    Hessian.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a list of dictionaries of gradients, one dictionary per
            sample. In each dictionary, the keys are layer names and the values
            are the corresponding gradients.
        argnums (int): An integer, defaults to 0. Specifies which argument of
            `func` to compute layer-wise gradients with respect to.
        regularization (List[float], optional): A list of length L, where L is
            the total number of layers. The term `regularization[l] * I`, where
            `I` is the identity matrix, is added directly to the Hessian of
            layer l; this is useful when that Hessian is singular or
            ill-conditioned. If None, 0.0 is used for every layer. Note that
            the closed form divides by each value, so in practice they should
            be positive.

    Returns:
        A function that takes a list of tuples of Tensors `x` (arguments for
        `func`) and a vector `v`, and returns the DataInf approximation of the
        IHVP of the Hessian of `func` and `v`.
    '''
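One way the proposed `ihvp_datainf` could be structured is sketched below. Several details are assumptions, not part of the proposal: the name `ihvp_datainf_sketch`, passing `v` as a `{layer_name: tensor}` dict with the same shapes as the gradients, omitting `argnums`, and making `regularization` a required argument because the closed form divides by each value.

```python
from typing import Callable, Dict, List, Sequence, Tuple

import torch

GradDict = Dict[str, torch.Tensor]


def ihvp_datainf_sketch(
    func: Callable[..., List[GradDict]],
    regularization: List[float],
) -> Callable[[Sequence[Tuple], GradDict], GradDict]:
    """Hypothetical sketch of the proposed `ihvp_datainf` (not dattri code).

    func(*x) returns one {layer_name: gradient} dict per sample in x.
    regularization[l] is the positive lambda of the l-th layer, in the
    order in which layers appear in those dicts.
    """

    def _ihvp(x_list: Sequence[Tuple], v: GradDict) -> GradDict:
        # Gather the per-sample gradient dicts for every argument tuple.
        per_sample = [g for x in x_list for g in func(*x)]
        names = list(per_sample[0].keys())
        if len(regularization) != len(names):
            raise ValueError("expected one regularization value per layer")
        n = len(per_sample)
        result = {}
        for name, lam in zip(names, regularization):
            grads = torch.stack([g[name].flatten() for g in per_sample])  # (n, d_l)
            v_l = v[name].flatten()                                       # (d_l,)
            coef = (grads @ v_l) / (lam + (grads * grads).sum(dim=1))     # (n,)
            ihvp_l = (n * v_l - grads.T @ coef) / (n * lam)               # (d_l,)
            result[name] = ihvp_l.reshape(v[name].shape)
        return result

    return _ihvp
```

For example, if `func` returns gradients for `"weight"` and `"bias"`, then `ihvp_datainf_sketch(func, [0.1, 0.1])(x_list, v)` returns a dict with those two keys, each shaped like the corresponding entry of `v`.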

def ihvp_at_x_datainf(func: Callable,
                      *x,
                      argnums: int = 0,
                      regularization: Optional[List[float]] = None,
                      ) -> Callable:
    '''DataInf ihvp algorithm function (with fixed x).

    Standing for the inverse-Hessian-vector product, returns a function that,
    when given a vector, computes the product of the inverse Hessian and the
    vector at the fixed inputs `x`.

    DataInf assumes the loss is cross-entropy (negative log-likelihood) and
    therefore derives a closed-form ihvp without having to approximate the
    Hessian.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a list of dictionaries of gradients, one dictionary per
            sample. In each dictionary, the keys are layer names and the values
            are the corresponding gradients.
        *x: The arguments for `func`, fixed when the IHVP function is created.
        argnums (int): A keyword-only integer, defaults to 0. Specifies which
            argument of `func` to compute layer-wise gradients with respect to.
        regularization (List[float], optional): A list of length L, where L is
            the total number of layers. The term `regularization[l] * I`, where
            `I` is the identity matrix, is added directly to the Hessian of
            layer l; this is useful when that Hessian is singular or
            ill-conditioned. If None, 0.0 is used for every layer. Note that
            the closed form divides by each value, so in practice they should
            be positive.

    Returns:
        A function that takes a vector `v` and returns the DataInf
        approximation of the IHVP of the Hessian of `func` (at `x`) and `v`.
    '''
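The fixed-`x` variant has a practical advantage worth capturing in an implementation: since `x` does not change, the per-layer gradient matrices and squared gradient norms can be computed once and reused for every `v`. A minimal sketch follows; the name `ihvp_at_x_datainf_sketch`, the dict-valued `v` keyed by layer name, the omission of `argnums`, and the required keyword-only `regularization` are all assumptions. Any keyword arguments must come after `*x`; otherwise the first element of `x` would be bound to them.

```python
from typing import Callable, Dict, List

import torch

GradDict = Dict[str, torch.Tensor]


def ihvp_at_x_datainf_sketch(
    func: Callable[..., List[GradDict]],
    *x,
    regularization: List[float],
) -> Callable[[GradDict], GradDict]:
    """Hypothetical sketch of the proposed `ihvp_at_x_datainf` (not dattri code)."""
    # x is fixed, so the per-sample gradients are computed once, up front.
    per_sample = func(*x)
    names = list(per_sample[0].keys())
    if len(regularization) != len(names):
        raise ValueError("expected one regularization value per layer")
    n = len(per_sample)
    cache = {}
    for name in names:
        grads = torch.stack([g[name].flatten() for g in per_sample])  # (n, d_l)
        cache[name] = (grads, (grads * grads).sum(dim=1))              # G_l and ||g_i||^2

    def _ihvp(v: GradDict) -> GradDict:
        # Each call only needs O(n * d_l) inner products with v per layer.
        result = {}
        for name, lam in zip(names, regularization):
            grads, norm_sq = cache[name]
            v_l = v[name].flatten()
            coef = (grads @ v_l) / (lam + norm_sq)
            ihvp_l = (n * v_l - grads.T @ coef) / (n * lam)
            result[name] = ihvp_l.reshape(v[name].shape)
        return result

    return _ihvp
```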

Demonstration

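import torch
import torch.nn as nn

# `model` and `L` (its number of layers) are assumed to be defined elsewhere.
criterion = nn.CrossEntropyLoss()

def get_layer_wise_grads(inputs, labels):
    # Returns one {layer_name: gradient} dict per (input, label) pair.
    grads = []
    for x, y in zip(inputs, labels):
        model.zero_grad()
        loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        gradients = {}
        for name, parameter in model.named_parameters():
            # Store a copy of the gradient, not the parameter itself.
            gradients[name] = parameter.grad.detach().clone()
        grads.append(gradients)
    return grads

# `lambda` is a Python keyword; use a list of L positive per-layer values.
regularization = torch.rand(L).tolist()
ihvp = ihvp_datainf(get_layer_wise_grads, regularization=regularization)
input_list = [
    (torch.randn(4, 10), torch.randint(0, 2, (4,))),  # (inputs, labels); shapes are placeholders
    (torch.randn(4, 10), torch.randint(0, 2, (4,))),
]
vec = torch.randn(5, 2)  # placeholder; must match the model's parameter shapes
ihvp(input_list, vec)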
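Defining `func` is likely the hardest part for users. As one possible concrete `func` for an `nn.Module` classifier trained with cross-entropy, the sketch below computes per-sample gradients with `torch.func` (`vmap` over `grad`, PyTorch 2.0+) instead of a Python loop over `backward()`, and splits them into one `{parameter_name: gradient}` dict per sample, matching the list-of-dicts contract in the docstrings. `make_layer_wise_grads` is a hypothetical name, and treating each named parameter as a "layer" is an assumption.

```python
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap


def make_layer_wise_grads(model: nn.Module):
    """Build a `func` that returns one {layer_name: gradient} dict per sample."""
    # Snapshot of the current parameters, detached from autograd.
    params = {name: p.detach() for name, p in model.named_parameters()}

    def sample_loss(params, x, y):
        # Loss of a single sample: add a batch dimension of size 1.
        logits = functional_call(model, params, (x.unsqueeze(0),))
        return F.cross_entropy(logits, y.unsqueeze(0))

    # Gradient w.r.t. params (argument 0), vectorized over the sample dimension.
    per_sample_grad = vmap(grad(sample_loss), in_dims=(None, 0, 0))

    def layer_wise_grads(inputs: torch.Tensor, labels: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
        batched = per_sample_grad(params, inputs, labels)  # {name: (n, *param_shape)}
        return [{name: g[i] for name, g in batched.items()} for i in range(inputs.shape[0])]

    return layer_wise_grads
```

Because the parameters are snapshotted when `make_layer_wise_grads` is called, it should be rebuilt if the model changes. For LoRA-style fine-tuning one would presumably keep only the trainable adapter parameters in `params`.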
TheaperDeng commented 4 months ago

Overall, the API looks good to me. Please define `func` carefully, since it is likely the most difficult part for users. Preferably, add an example to the docstring.

charles-pyj commented 4 months ago

Thanks!