TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.func] update cg ihvp to support vmap #61

Closed TheaperDeng closed 3 months ago

TheaperDeng commented 3 months ago

Description

1. Motivation and Context

Previously, we carried out a for-loop over the vector's first dim, which is inefficient.

Now we use vmap to parallelize this process.

2. Summary of the change

  1. Change the CG ihvp implementation to use vmap rather than a for-loop (a sketch of the vmap-friendly loop follows this list).
  2. TODO: vmap does not support data-dependent control flow (https://github.com/pytorch/functorch/issues/257). A torch.cond is provided, but it's still a prototype, so I will leave this as a TODO. This means there is no automatic break condition; users need to set max_iter carefully.
  3. Reduce the default max_iter, since most of the time a large number of iterations is not needed.
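
For illustration, here is a minimal sketch of a vmap-friendly CG solve with a fixed trip count (no residual-based early exit, per the TODO above). The names cg_solve and ihvp_batched are hypothetical, not dattri's actual API, and hvp_fn is assumed to be built from torch.func transforms so it composes with vmap:

```python
import torch
from torch.func import vmap

def cg_solve(hvp_fn, b, max_iter=10, damping=0.0):
    # Solve (H + damping * I) x = b with a FIXED number of CG steps.
    # No residual-based break: vmap cannot handle data-dependent control
    # flow, so max_iter has to be chosen carefully by the user.
    x = torch.zeros_like(b)
    r = b.clone()              # residual b - H @ x (x starts at 0)
    p = r.clone()              # search direction
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):  # fixed trip count -- vmap-friendly
        hp = hvp_fn(p) + damping * p
        alpha = rs_old / torch.dot(p, hp)
        x = x + alpha * p
        r = r - alpha * hp
        rs_new = torch.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

def ihvp_batched(hvp_fn, vecs, max_iter=10):
    # Batch the solve over the first dim of `vecs` via vmap,
    # instead of a Python for-loop over rows.
    return vmap(lambda v: cg_solve(hvp_fn, v, max_iter=max_iter))(vecs)
```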

3. What tests have been added/updated for the change?

TheaperDeng commented 3 months ago

I'm wondering if the argument x in ihvp_at_x_cg could be batched, i.e., will there be a use case where we would need a list of hvp_at_x_func as in PR #30?

Thinking about the influence function use case, it seems that we won't need it to be batched. In fact, the current implementation in PR #30 seems to be incorrect, since the Hessian in hvp_at_x_func should be the Hessian of the loss over all the training data.

I think the list of x is a special need of the LiSSA algorithm; it's not needed by the CG algorithm.

i.e.,

- For CG, the Hessian is calculated from the loss over all the training data.
- For LiSSA, the Hessian is calculated from the loss over batches of training data (each batch is one iteration).
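
To illustrate the distinction, here is a toy sketch (hypothetical names; make_hvp, loss_fn, etc. are not dattri's API) of building one full-data hvp closure for CG versus a list of per-batch closures for LiSSA:

```python
import torch
from torch.func import grad, jvp

def make_hvp(loss_fn, params, data):
    # Hessian-vector product of loss_fn(params, data), forward-over-reverse.
    g = grad(lambda p: loss_fn(p, data))
    return lambda v: jvp(g, (params,), (v,))[1]

# Toy quadratic loss so the sketch runs end to end.
def loss_fn(p, data):
    return ((data @ p) ** 2).mean()

params = torch.randn(4)
full_data = torch.randn(32, 4)
batches = full_data.split(8)

hvp_for_cg = make_hvp(loss_fn, params, full_data)                # one Hessian, all data
hvp_for_lissa = [make_hvp(loss_fn, params, b) for b in batches]  # one Hessian per batch
```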

jiaqima commented 3 months ago

> I'm wondering if the argument x in ihvp_at_x_cg could be batched, i.e., will there be a use case where we would need a list of hvp_at_x_func as in PR #30? Thinking about the influence function use case, it seems that we won't need it to be batched. In fact, the current implementation in PR #30 seems to be incorrect, since the Hessian in hvp_at_x_func should be the Hessian of the loss over all the training data.
>
> I think the list of x is a special need of the LiSSA algorithm; it's not needed by the CG algorithm.
>
> i.e.,
>
> - For CG, the Hessian is calculated from the loss over all the training data.
> - For LiSSA, the Hessian is calculated from the loss over batches of training data (each batch is one iteration).

I see. In this case, I think there won't be much gain in having the "_at_x" version of LiSSA?

jiaqima commented 3 months ago

In addition, it could be a separate PR, but I think the current ihvp_cg could also be made more efficient by calling hvp once instead of calling ihvp_at_x_cg?

TheaperDeng commented 3 months ago

> > I'm wondering if the argument x in ihvp_at_x_cg could be batched, i.e., will there be a use case where we would need a list of hvp_at_x_func as in PR #30? Thinking about the influence function use case, it seems that we won't need it to be batched. In fact, the current implementation in PR #30 seems to be incorrect, since the Hessian in hvp_at_x_func should be the Hessian of the loss over all the training data.
> >
> > I think the list of x is a special need of the LiSSA algorithm; it's not needed by the CG algorithm. I.e., for CG, the Hessian is calculated from the loss over all the training data; for LiSSA, the Hessian is calculated from the loss over batches of training data (each batch is one iteration).
>
> I see. In this case, I think there won't be much gain in having the "_at_x" version of LiSSA?

You can still generate a list of hvp functions over the batches of x, right (compared to the non-"_at_x" version)? That may take some time, and you can then reuse the list of hvp functions later in the LiSSA core algorithm.
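
A rough sketch of that idea (hypothetical lissa_at_x; the recursion follows the standard LiSSA update, not necessarily dattri's exact implementation): the per-batch hvp closures are built once up front and then reused for every query vector.

```python
import torch
from torch.func import grad, jvp

def lissa_at_x(loss_fn, params, batches, damping=0.0, scale=10.0):
    # Pay the closure-construction cost once, up front.
    def make_hvp(batch):
        g = grad(lambda p: loss_fn(p, batch))
        return lambda v: jvp(g, (params,), (v,))[1]

    hvp_fns = [make_hvp(b) for b in batches]

    def ihvp(v):
        # Standard LiSSA recursion:
        # r_j = v + (1 - damping) * r_{j-1} - H_j r_{j-1} / scale
        estimate = v.clone()
        for hvp in hvp_fns:  # one batch per LiSSA iteration
            estimate = v + (1 - damping) * estimate - hvp(estimate) / scale
        return estimate / scale  # undo the scaling of H

    return ihvp
```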

jiaqima commented 3 months ago

> > > I'm wondering if the argument x in ihvp_at_x_cg could be batched, i.e., will there be a use case where we would need a list of hvp_at_x_func as in PR #30? Thinking about the influence function use case, it seems that we won't need it to be batched. In fact, the current implementation in PR #30 seems to be incorrect, since the Hessian in hvp_at_x_func should be the Hessian of the loss over all the training data.
> > >
> > > I think the list of x is a special need of the LiSSA algorithm; it's not needed by the CG algorithm. I.e., for CG, the Hessian is calculated from the loss over all the training data; for LiSSA, the Hessian is calculated from the loss over batches of training data (each batch is one iteration).
> >
> > I see. In this case, I think there won't be much gain in having the "_at_x" version of LiSSA?
>
> You can still generate a list of hvp functions over the batches of x, right (compared to the non-"_at_x" version)? That may take some time, and you can then reuse the list of hvp functions later in the LiSSA core algorithm.

Yeah, that makes sense.

TheaperDeng commented 3 months ago

> In addition, it could be a separate PR, but I think the current ihvp_cg could also be made more efficient by calling hvp once instead of calling ihvp_at_x_cg?

Yeah, I can first generate an hvp function before going into the CG core algorithm, though the time saved might be trivial. I can do it in the next PR.
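
A sketch of the suggested refactor (hypothetical signature; cg_solve refers to the fixed-iteration solver sketched earlier in this thread): the hvp closure is built once, outside the per-vector CG solve.

```python
from torch.func import grad, jvp, vmap

def ihvp_cg(loss_fn, params, data, vecs, max_iter=10):
    # Build the hvp closure a single time, before entering the CG core,
    # instead of reconstructing it on every call.
    g = grad(lambda p: loss_fn(p, data))
    hvp = lambda v: jvp(g, (params,), (v,))[1]
    # vmap parallelizes the solve over the first dim of `vecs`.
    return vmap(lambda v: cg_solve(hvp, v, max_iter=max_iter))(vecs)
```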