TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.func.ihvp] API Design Proposal for EK-FAC ihvp #76

Closed: sx-liu closed this issue 2 months ago.

sx-liu commented 3 months ago

Background

This issue proposes an API for the Eigenvalue-Corrected Kronecker-Factored Approximate Curvature (EK-FAC) ihvp algorithm. At its core, EK-FAC ihvp estimates the inverse-matrix-vector product of the Fisher Information Matrix, or equivalently of the Gauss-Newton Hessian, which serves as an approximation of the Hessian.

Algorithm description

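EK-FAC ihvp can be computed layer-wise. For a linear layer $\ell$ with weight matrix $W$:

$$ (G + \lambda I)^{-1} v = \operatorname{vec}\left( Q_S^\top \left[ \left( Q_S \bar{V} Q_A^\top \right) \oslash \operatorname{unvec}\left( \operatorname{diag}^{-1}(\Lambda + \lambda I) \right) \right] Q_A \right), $$

where $G$ is the layer's (EK-FAC approximated) Gauss-Newton Hessian, $\bar V$ is the vector $v$ reshaped to match $W$, $\lambda$ is the damping factor, and $\oslash$ denotes elementwise division. All other variables are derived from $a_{\ell - 1}$ (the layer input) and $\mathcal{D}s_\ell$ (the gradient with respect to the hidden state $s_\ell$ before the activation): $Q_A$ and $Q_S$ are the eigenbases of the covariance matrices $A_\ell = \mathbb{E}[a_{\ell-1} a_{\ell-1}^\top]$ and $S_\ell = \mathbb{E}[\mathcal{D}s_\ell \mathcal{D}s_\ell^\top]$, and $\Lambda$ contains the corrected eigenvalues, i.e. the second moments of the per-example weight gradients projected onto that eigenbasis. Both $a_{\ell-1}$ and $\mathcal{D}s_\ell$ can be obtained by caching the appropriate tensors during a forward and backward pass.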

This formula applies only to fully connected (MLP) layers. When approximating influence for LLMs, only the MLP layers will therefore be considered.
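As a minimal sketch of the layer-wise formula, assuming per-example layer inputs `a` (shape `(n, d_in)`) and pre-activation gradients `ds` (shape `(n, d_out)`) are already available, the PyTorch code below estimates $Q_A$, $Q_S$ and $\Lambda$ and then applies the formula. The helper names `ekfac_factors` and `ekfac_layer_ihvp` are hypothetical and not part of dattri. Eigenvectors are stored as rows so that $Q_S \bar V Q_A^\top$ rotates into the eigenbasis exactly as in the formula, and $\operatorname{vec}$ corresponds to a row-major `reshape(-1)`.

```python
import torch


def ekfac_factors(a: torch.Tensor, ds: torch.Tensor):
    """Estimate Q_A, Q_S and the corrected eigenvalues for one linear layer.

    Hypothetical helper for illustration.
    a:  (n, d_in)  per-example layer inputs a_{l-1}
    ds: (n, d_out) per-example gradients w.r.t. the pre-activation s_l
    """
    n = a.shape[0]
    cov_a = a.T @ a / n                  # A = E[a a^T]
    cov_s = ds.T @ ds / n                # S = E[Ds Ds^T]
    _, u_a = torch.linalg.eigh(cov_a)    # columns are eigenvectors
    _, u_s = torch.linalg.eigh(cov_s)
    q_a, q_s = u_a.T, u_s.T              # rows are eigenvectors, as in the formula
    # Per-example weight gradients DW_i = Ds_i a_i^T, shape (n, d_out, d_in)
    grads = ds.unsqueeze(2) * a.unsqueeze(1)
    # Eigenvalue correction: second moment of the gradients in the Kronecker eigenbasis
    lambdas = ((q_s @ grads @ q_a.T) ** 2).mean(dim=0)  # (d_out, d_in)
    return q_a, q_s, lambdas


def ekfac_layer_ihvp(v_bar, q_a, q_s, lambdas, damping=0.0):
    """(G + damping * I)^{-1} v for one layer; v_bar has the shape of W."""
    rotated = q_s @ v_bar @ q_a.T            # Q_S V Q_A^T
    scaled = rotated / (lambdas + damping)   # divide by unvec(Lambda + lambda)
    return q_s.T @ scaled @ q_a              # rotate back: Q_S^T [...] Q_A
```

For tiny layers the result can be checked against a dense solve of $(G + \lambda I)$ built from the same eigenbasis; avoiding that $d_{out} d_{in} \times d_{out} d_{in}$ matrix is the point of the Kronecker factorization.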

API Design

The main function for EK-FAC ihvp calculation:

```python
from typing import Callable, List, Optional, Tuple, Union


def ihvp_at_x_ekfac(func: Callable,
                    *x,
                    in_dims: Optional[Tuple] = None,
                    batch_size: int = 1,
                    max_iter: Optional[int] = None,
                    # keyword-only (follows *x), so it must be passed as mlp_cache=...
                    mlp_cache: Union["MLPCache", List["MLPCache"]],
                    damping: float = 0.0) -> Callable:
    """IHVP via the EK-FAC algorithm.

    Standing for the inverse-Hessian-vector product, returns a function that,
    when given vectors, computes the product of the inverse Hessian and the vectors.

    The EK-FAC algorithm provides a layer-wise approximation of the ihvp,
    using the Gauss-Newton Hessian in place of the true Hessian.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a Tensor of shape (batch_size * t,), where t can be an
            arbitrary integer. The Hessian will be estimated on this function.
        *x: List of arguments for `func`.
        in_dims (Tuple, optional): A tuple with the same shape as *x, indicating
            which dimension should be considered as the batch size dimension.
            By default, the first dimension is taken as the batch size dimension.
        batch_size (int): The batch size used for estimating the covariance
            matrices and lambdas. Defaults to 1.
        max_iter (int, optional): The maximum number of batches used for
            estimating the covariance matrices and lambdas.
        mlp_cache (Union[MLPCache, List[MLPCache]]): A single registered cache
            or a list of them, used to record the input and hidden vectors as
            well as their gradients during the forward and backward calls of
            `func`.
        damping (float): Damping factor added to the eigenvalues in the EK-FAC
            ihvp calculation to handle non-convexity. Defaults to 0.0.

    Returns:
        A function that takes a tuple of Tensor `x` and a single vector or a
        list of vectors `v`, and returns the IHVP of the Hessian of `func` and `v`.
    """
    ...
```
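To illustrate how `batch_size` and `max_iter` could bound the estimation cost, here is a sketch of streaming estimation of the two Kronecker factors for a single layer, assuming the cached layer inputs and pre-activation gradients arrive one batch at a time. `estimate_covariances` is a hypothetical helper, not dattri API, and treating `max_iter=None` as "use all batches" is an assumption of this sketch. In EK-FAC the lambdas would then be fitted in a second pass over the data, with the eigenbasis of these covariances held fixed.

```python
import itertools
from typing import Iterable, Optional, Tuple

import torch


def estimate_covariances(
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    max_iter: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Running estimates of A = E[a a^T] and S = E[Ds Ds^T] for one layer.

    Each batch is (a, ds) with a: (b, d_in) and ds: (b, d_out).
    At most `max_iter` batches are consumed; None means all of them.
    """
    sum_a, sum_s, count = None, None, 0
    for a, ds in itertools.islice(batches, max_iter):
        if sum_a is None:
            sum_a = torch.zeros(a.shape[1], a.shape[1], dtype=a.dtype)
            sum_s = torch.zeros(ds.shape[1], ds.shape[1], dtype=ds.dtype)
        sum_a += a.T @ a      # accumulate unnormalized second moments
        sum_s += ds.T @ ds
        count += a.shape[0]   # number of examples seen so far
    if count == 0:
        raise ValueError("no batches were provided")
    return sum_a / count, sum_s / count
```

Accumulating unnormalized sums and dividing once at the end keeps the estimate exact regardless of how the data is split into batches.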

For caching, we may also need to borrow some basic functions and classes from https://github.com/JacksonWuxs/UsableXAI_LLM/blob/b89dbf06d1b36cb2c2538feb0388098ef9a738d4/libs/core/hooks.py

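Definition of the cache:

```python
from dataclasses import dataclass, field
from typing import List, Tuple

import torch


@dataclass
class MLPCache:
    # (layer input, pre-activation hidden state) pairs, in forward order.
    # default_factory is required: dataclasses reject a mutable [] default.
    input_hidden_pairs: List[Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=list)

    def collect_states(self):
        for each in self.input_hidden_pairs:
            assert isinstance(each[0], torch.Tensor) and isinstance(each[1], torch.Tensor)
        return self.input_hidden_pairs

    def named_parameters(self):
        ...

    def clear(self):
        ...

    def zero_grad(self):
        ...

    def retain_grad(self):
        ...

    def check_type(self):
        ...
```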
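The stub methods above are left open in this proposal. One possible way to fill them in is sketched below, assuming the cache only needs gradients of the pre-activation hidden states (the $\mathcal{D}s_\ell$ used by EK-FAC). `MLPCacheSketch` and its `hidden_grads` method are hypothetical names, and `named_parameters` is omitted because its intended semantics are not specified.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

import torch


@dataclass
class MLPCacheSketch:
    input_hidden_pairs: List[Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=list)

    def clear(self):
        # Drop the pairs recorded by the previous forward call
        self.input_hidden_pairs.clear()

    def check_type(self):
        for inp, hidden in self.input_hidden_pairs:
            if not (isinstance(inp, torch.Tensor) and isinstance(hidden, torch.Tensor)):
                raise TypeError("cached entries must be (Tensor, Tensor) pairs")

    def retain_grad(self):
        # Hidden states are non-leaf tensors; ask autograd to keep their .grad.
        # retain_grad() raises on tensors that do not require grad, hence the guard.
        for _, hidden in self.input_hidden_pairs:
            if hidden.requires_grad:
                hidden.retain_grad()

    def zero_grad(self):
        for _, hidden in self.input_hidden_pairs:
            hidden.grad = None

    def hidden_grads(self) -> List[torch.Tensor]:
        # Gradients w.r.t. each cached pre-activation, available after backward()
        return [hidden.grad for _, hidden in self.input_hidden_pairs]
```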

Decorator for caching in the forward pass of MLP modules:

```python
from functools import wraps

KEY = "__cache"


def manual(forward_func):
    """Manually rewrite the forward function to collect the variables of interest."""
    @wraps(forward_func)
    def cached_forward(self, *args, **kwargs):
        # No cache registered on this module: run the plain forward pass
        if not hasattr(self, KEY):
            return forward_func(self, *args, **kwargs)
        cache = getattr(self, KEY)
        cache.clear()        # drop pairs recorded by the previous call
        outputs = forward_func(self, *args, **kwargs)
        cache.check_type()
        cache.retain_grad()  # keep .grad on cached non-leaf tensors after backward
        return outputs
    return cached_forward
```
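As a design alternative (not part of this proposal), PyTorch's `register_forward_hook` can record the same (input, pre-activation) pairs for `nn.Linear` layers without rewriting `forward`. The sketch below assumes a plain list as the cache and uses a hypothetical helper `attach_linear_hooks`. The trade-off: hooks fire on every selected layer automatically, while the `manual` decorator gives explicit control over which tensors are cached.

```python
from typing import List, Tuple

import torch
from torch import nn


def attach_linear_hooks(model: nn.Module,
                        pairs: List[Tuple[torch.Tensor, torch.Tensor]]):
    """Record (input, pre-activation output) for every nn.Linear in `model`."""
    def hook(module, inputs, output):
        # `inputs` is the tuple of positional args passed to module.forward
        if output.requires_grad:
            output.retain_grad()  # keep Ds for the backward pass
        pairs.append((inputs[0].detach(), output))
    # Return the handles so the hooks can be removed with handle.remove()
    return [m.register_forward_hook(hook)
            for m in model.modules() if isinstance(m, nn.Linear)]
```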

Demonstration

Users may need to redefine the forward method of their MLP layers to enable caching. Taking `dattri.benchmark.models.mlp.MLPMnist` as an example:

```python
import torch
from dattri.benchmark.models.mlp import MLPMnist


@manual
def custom_forward_method(self, hidden_states):
    if not hasattr(self, KEY):
        # Normal forward pass (no cache registered)
        hidden_states = hidden_states.view(-1, 28*28)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.dropout1(torch.relu(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout2(torch.relu(hidden_states))
        hidden_states = self.fc3(hidden_states)
        return hidden_states
    # Cached forward pass: record (input, pre-activation) pairs of fc1 and fc2
    cache = getattr(self, KEY)
    x1 = hidden_states.view(-1, 28*28)
    y1 = self.fc1(x1)
    cache.input_hidden_pairs.append((x1, y1))
    x2 = self.dropout1(torch.relu(y1))
    y2 = self.fc2(x2)
    cache.input_hidden_pairs.append((x2, y2))
    outputs = self.fc3(self.dropout2(torch.relu(y2)))
    return outputs

MLPMnist.forward = custom_forward_method
```

Then register the cache on the MLP instance and call the ihvp function:

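```python
mlp = MLPMnist()
cache = MLPCache()
setattr(mlp, KEY, cache)

# flatten_func / flatten_params: dattri parameter-flattening helpers (import omitted)
model_params = {k: p for k, p in mlp.named_parameters()}

@flatten_func(mlp, param_num=0)
def f(params):
    ...

# mlp_cache follows *x in the signature, so it must be passed by keyword;
# passed positionally, the cache would be swallowed into *x
ihvp_at_x_ekfac(f, flatten_params(model_params), mlp_cache=cache)
```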
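To connect the demonstration to the EK-FAC inputs, here is a self-contained toy version, assuming a two-layer `ToyMLP` as a stand-in for `MLPMnist` and a minimal list-based `ToyCache` as a stand-in for `MLPCache` (both hypothetical, for illustration only). After a backward pass, each cached pair yields the layer input $a_{\ell-1}$ and, through `.grad` on the retained hidden state, $\mathcal{D}s_\ell$. Their batched outer products reproduce the layer's weight gradient, which is why these two tensors are enough to estimate the EK-FAC factors.

```python
from functools import wraps

import torch
from torch import nn

KEY = "__cache"


class ToyCache:
    """Minimal stand-in for MLPCache (hypothetical)."""
    def __init__(self):
        self.input_hidden_pairs = []

    def clear(self):
        self.input_hidden_pairs.clear()

    def check_type(self):
        assert all(isinstance(t, torch.Tensor)
                   for pair in self.input_hidden_pairs for t in pair)

    def retain_grad(self):
        for _, hidden in self.input_hidden_pairs:
            hidden.retain_grad()


def manual(forward_func):
    @wraps(forward_func)
    def cached_forward(self, *args, **kwargs):
        if not hasattr(self, KEY):
            return forward_func(self, *args, **kwargs)
        cache = getattr(self, KEY)
        cache.clear()
        outputs = forward_func(self, *args, **kwargs)
        cache.check_type()
        cache.retain_grad()
        return outputs
    return cached_forward


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(6, 5)
        self.fc2 = nn.Linear(5, 3)

    @manual
    def forward(self, x):
        y1 = self.fc1(x)  # pre-activation hidden state s_1
        if hasattr(self, KEY):
            getattr(self, KEY).input_hidden_pairs.append((x, y1))
        return self.fc2(torch.relu(y1))


torch.manual_seed(0)
model = ToyMLP()
cache = ToyCache()
setattr(model, KEY, cache)
x = torch.randn(4, 6)
model(x).sum().backward()
a, hidden = cache.input_hidden_pairs[0]
ds = hidden.grad  # Ds_1, one row per example
# Summing the per-example outer products Ds_i a_i^T gives the fc1 weight gradient
weight_grad_from_cache = ds.T @ a
```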