TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.func] Implement ihvp_lissa and ihvp_at_x_lissa #30

Closed. sx-liu closed this pull request 4 months ago.

sx-liu commented 4 months ago

Description

1. Motivation and Context

To complete the proposal in issue #22.

2. Summary of the change

Implemented the ihvp_lissa and ihvp_at_x_lissa functions under dattri\func\ihvp.py.
Implemented their unit tests under test\dattri\func\test_ihvp.py.

3. What tests have been added/updated for the change?

sx-liu commented 4 months ago

Hi Jiaqi, I found torch.vmap a little tricky. It turns out that the function passed to vmap can only return tensors or lists/tuples of tensors; it cannot return functions.

In the first loop, we need to pre-compile a list of hvp functions, so it does not make sense to apply vmap directly there. Are there any alternative functions in PyTorch that serve a similar purpose?

Nevertheless, I have updated the second loop with vmap, which I think will definitely be more efficient.

jiaqima commented 4 months ago

How about this: instead of a list of functions, each returning a tensor for a single data point, define a single function that returns a list of tensors for all data points? You could then apply vmap inside this single function.
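A minimal sketch of that suggestion, written with torch.func rather than dattri's own hvp helpers: a single function, hvp_all, returns the HVPs for all data points at once by applying vmap over a stacked tensor of data points. The per-sample loss and the names hvp_one/hvp_all are hypothetical, and the HVP is computed forward-over-reverse (jvp of grad). Returning one stacked tensor of shape (N, d), rather than a list of functions, satisfies vmap's return-type restriction.

```python
import torch
from torch.func import grad, jvp, vmap


def loss(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical smooth per-sample loss, for illustration only
    return torch.sum(torch.tanh(x * theta) ** 2)


def hvp_one(theta: torch.Tensor, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hessian-vector product at a single data point: jvp of the gradient
    _, hv = jvp(lambda t: grad(loss)(t, x), (theta,), (v,))
    return hv


def hvp_all(theta: torch.Tensor, xs: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # One function returning the HVPs for all data points (rows of xs);
    # vmap over dim 0 of xs replaces the list of per-point hvp functions
    return vmap(hvp_one, in_dims=(None, 0, None))(theta, xs, v)
```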

sx-liu commented 4 months ago

But how could we fit everything into one such function? The hvps are calculated sequentially, one data point after another.

Or do you mean that our hvp function would calculate the hvp for all data points at every step, and then we just randomly pick one of them as the update?

jiaqima commented 4 months ago

I'm thinking about the following.

Say we want to parallelize this double loop:

results = torch.zeros(size=(len(x_list), len(y_list)))
for i, x in enumerate(x_list):
    for j, y in enumerate(y_list):
        results[i, j] = func(x, y)

Can we do this?

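_func_at_batch_x_func = vmap(lambda y: vmap(lambda x: func(x, y))(x_list))
results = _func_at_batch_x_func(y_list).T  # x_list, y_list as stacked tensors; .T so results[i, j] = func(x_i, y_j)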

In this case, you may want to use hvp instead of hvp_at_x to get func.
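A small, self-contained check of the nested-vmap idea, assuming func maps a pair of 1-D tensors to a scalar and that x_list/y_list are stacked into tensors xs/ys (vmap maps over dim 0 of tensors, not over Python lists). The body of func here is only a placeholder. Putting the vmap over x outermost makes the output index order match results[i, j]; with y outermost, as in the one-liner above, the result comes out transposed.

```python
import torch
from torch.func import vmap


def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Placeholder for any scalar-valued function of a pair (x, y)
    return torch.dot(x, y) + x.sum() * torch.sin(y).sum()


def double_loop(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    # Reference: the explicit double for-loop
    results = torch.zeros(len(xs), len(ys))
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            results[i, j] = func(x, y)
    return results


def double_vmap(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    # Outer vmap over rows of xs, inner vmap over rows of ys,
    # so the output satisfies out[i, j] == func(xs[i], ys[j])
    return vmap(lambda x: vmap(lambda y: func(x, y))(ys))(xs)
```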

jiaqima commented 4 months ago

At a minimum, on top of your current implementation, you should move the following snippet outside of def _ihvp_at_x_lissa_func(v: Tensor) -> Tensor:

        hvp_func_list = [
            hvp_at_x(func, x=data_point, argnums=argnums, mode=mode)
            for data_point in input_list
        ]

Note that the snippet above does not depend on v at all, so it shouldn't be recomputed for every v.

I'm not entirely sure which is more efficient: this for-loop implementation or the double-vmap implementation I suggested earlier.

If there are many batches of v and only one batch of x in the calculation, then this for-loop implementation may be more efficient, as hvp_func_list will be compiled only once. This is because, while the double-vmap implementation avoids the for loop over the batch of x, it doesn't pre-compile anything for x.

But if there are multiple batches of x, then the double-vmap implementation may be more efficient.

Furthermore, if hvp is called in rev-fwd mode, I think the double-vmap implementation will always be more efficient, and there is no pre-compilation in this mode anyway.

@TheaperDeng please feel free to chime in.

jiaqima commented 4 months ago

I think the case for ihvp without "at_x" is clearer. The double-vmap implementation should be more efficient.

sx-liu commented 4 months ago

It seems that vmap cannot fully parallelize our LiSSA function. This is a technical limitation rather than a theoretical one.

The first problem is that vmap does not fully support randomness for objects other than tensors. Suppose we try to turn a double loop into a nested vmap like this:

    def _lissa_inner_loop(vec, sampled_hvp_func_list):
        # sampled_indices = random.sample(
        #     list(range(batch_size)),
        #     recursion_depth,
        # )
        sampled_indices = torch.randperm(len(input_list))[:recursion_depth]
        sampled_hvp_func_list = [hvp_func_list[idx] for idx in sampled_indices]

        curr_estimate = vec.detach().clone()  # No gradient on v
        for hvp_func in sampled_hvp_func_list:
            hvp = hvp_func(curr_estimate)
            curr_estimate = (vec
                             + (1 - damping) * curr_estimate
                             - hvp / scaling)

        return curr_estimate / scaling

    def _lissa_outer_loop(vec: torch.Tensor) -> torch.Tensor:
        lissa_inner_map = torch.vmap(_lissa_inner_loop,
                                     randomness="different")

        batch_sampled_indices = [
            random.sample(
                list(range(batch_size)),
                recursion_depth,
            )
            for _ in range(num_repeat)
        ]
        batch_sampled_indices = torch.tensor(batch_sampled_indices)
        print(batch_sampled_indices.shape)

        ihvp_estimations = lissa_inner_map(vec, batch_sampled_indices)
        return torch.mean(ihvp_estimations, dim=0)

    v = v[:, None, ...]  # (N, 1, ...)
    repeat_dims = (1, num_repeat) + (1,) * (v.dim() - 2)
    v = v.repeat(*repeat_dims)  # (N, num_repeat, ...)
    return torch.vmap(_lissa_outer_loop, randomness="different")(v)

This code snippet is essentially equivalent to the double-loop implementation: the outer loop is over a batch of v's and the inner loop is over num_repeat. However, running it gives the following error:
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report..

The reason for this error is that we are indexing a Python list with the tensor sampled_indices. Indexing a list with a tensor calls .item() on it, which vmap does not currently support. See this post from February: https://github.com/pytorch/pytorch/issues/42368#issuecomment-1937427695.

The commented-out part shows another attempt, using Python's random module for the sampling, but it completely breaks the randomness: under vmap, this part of the code runs only once, so every batch element gets the same samples.
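One possible workaround, sketched with torch.func under stated assumptions (a hypothetical quadratic per-sample loss, data points stacked into a tensor data, and helper names lissa_chain/lissa_batched that are not dattri's API): keep the data points in a tensor instead of a list of per-point hvp functions, draw all sampled indices outside vmap, and pass them in as a batched input. Indexing a tensor with a tensor has a vmap batching rule, and sampling outside vmap gives every (v, repeat) pair its own indices. The recursion over depth stays a plain Python loop. The trade-off is that each HVP is recomputed from the loss instead of using pre-compiled per-point hvp functions.

```python
import torch
from torch.func import grad, jvp, vmap


def loss(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-sample loss: 0.5 * (x . theta)^2
    return 0.5 * torch.dot(x, theta) ** 2


def hvp_at(theta, x, v):
    # Forward-over-reverse HVP at one data point
    return jvp(lambda t: grad(loss)(t, x), (theta,), (v,))[1]


def lissa_chain(v, sampled_idx, theta, data, damping, scaling):
    # Index the data *tensor* with the sampled index tensor; unlike indexing
    # a Python list (which calls .item()), this works under vmap
    sampled_x = data[sampled_idx]  # (recursion_depth, d)
    curr = v
    for j in range(sampled_x.shape[0]):  # sequential recursion of fixed depth
        curr = v + (1 - damping) * curr - hvp_at(theta, sampled_x[j], curr) / scaling
    return curr / scaling


def lissa_batched(vs, idx, theta, data, damping=0.0, scaling=10.0):
    # vs: (N, d); idx: (N, num_repeat, recursion_depth) ints sampled outside vmap
    chain = lambda v, ind: lissa_chain(v, ind, theta, data, damping, scaling)
    over_repeats = vmap(chain, in_dims=(None, 0))  # same v, different index sets
    over_vs = vmap(over_repeats, in_dims=(0, 0))   # batch of v's
    return over_vs(vs, idx).mean(dim=1)            # average over repeats
```

A caller could draw idx with, for example, torch.randint(0, len(data), (N, num_repeat, recursion_depth)); note that randint samples with replacement, unlike random.sample.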

The second problem is that I am still not sure how to further parallelize the final loop over data points x. What we are actually doing is calculating the ihvp iteratively by the formula

$H_j^{-1}v = v + \left(I - \nabla_\theta^2 L(z_{s_j}, \hat\theta)\right) H_{j-1}^{-1}v$

How could we calculate $H_j^{-1}v$ without knowing $H_{j-1}^{-1}v$? I think the most we could do is pre-compile in parallel.
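A small numeric sketch of this recursion, in the scaled and damped form used by the inner loop earlier in the thread (curr_estimate = vec + (1 - damping) * curr_estimate - hvp / scaling, then dividing by scaling). An explicit matrix H stands in for the sampled per-sample Hessians, so this checks the fixed point rather than LiSSA's stochastic sampling; the function name lissa_recursion is hypothetical. The loop makes the sequential dependency explicit: each step needs the previous estimate.

```python
import torch


def lissa_recursion(hvp, v, depth, damping=0.0, scaling=1.0):
    # h_0 = v; h_j = v + (1 - damping) * h_{j-1} - hvp(h_{j-1}) / scaling
    # Each step depends on the previous one, so steps cannot run in parallel.
    h = v.clone()
    for _ in range(depth):
        h = v + (1 - damping) * h - hvp(h) / scaling
    # Fixed point: (H / scaling + damping * I) h = v,
    # so h / scaling = (H + damping * scaling * I)^{-1} v
    return h / scaling
```

With damping = 0 and scaling large enough that every eigenvalue of I - H/scaling lies in (-1, 1), the estimate converges to H^{-1}v; with damping > 0 it converges to (H + damping * scaling * I)^{-1}v instead.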

I am not sure whether I understood your previous suggestion on double-vmap correctly. Do you have any other thoughts on resolving the current bugs? Thanks!

jiaqima commented 4 months ago

Thanks for the update. I haven't thought about this level of detail yet, and it may indeed be infeasible to parallelize the whole thing with vmap.

@TheaperDeng is going to give it a try on parallelizing the CG algorithm. Maybe we can wait and see if he will have some success there.

sx-liu commented 4 months ago

The updated LiSSA implementation is ready for review. One thing to mention: for the batch sampling I used the same strategy (but a different implementation) as the one in https://github.com/nimarb/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_function.py, where a dataloader is used for random shuffling:

for x, t in z_loader:
    if gpu >= 0:
        x, t = x.cuda(), t.cuda()
    y = model(x)
    loss = calc_loss(y, t)
    params = [p for p in model.parameters() if p.requires_grad]
    hv = hvp(loss, params, h_estimate)
    # Recursively calculate h_estimate
    h_estimate = [
        _v + (1 - damp) * _h_e - _hv / scale
        for _v, _h_e, _hv in zip(v, h_estimate, hv)]
    break

We don't use a dataloader in our case. On second thought, it feels a little redundant to define a dataloader class just for LiSSA, and the current implementation naturally inherits the input format of the previous ihvp functions.
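A minimal sketch of dataloader-free sampling in the same spirit, assuming (as the break in the snippet above suggests) that each recursion step uses only the first mini-batch of a freshly shuffled loader: at every step, draw a random permutation of the indices and keep the first batch_size of them. The helper name sample_lissa_batches and the tuple-of-tensors input format are assumptions for illustration, not dattri's API.

```python
from typing import Iterator, Optional, Tuple

import torch


def sample_lissa_batches(
    inputs: Tuple[torch.Tensor, ...],
    batch_size: int,
    recursion_depth: int,
    generator: Optional[torch.Generator] = None,
) -> Iterator[Tuple[torch.Tensor, ...]]:
    # Hypothetical helper: one shuffled mini-batch per recursion step
    n = inputs[0].shape[0]
    for _ in range(recursion_depth):
        # Same effect as taking the first batch of a freshly shuffled dataloader
        perm = torch.randperm(n, generator=generator)[:batch_size]
        # Apply the same indices to every tensor so inputs and targets stay aligned
        yield tuple(t[perm] for t in inputs)
```

Passing an explicit torch.Generator keeps the sampling reproducible without touching the global RNG state.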