TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.func] Fixing Arnoldi's for-loop implementation #59

Closed: tingwl0122 closed this pull request 3 months ago.

tingwl0122 commented 3 months ago

Description

This PR fixes Arnoldi's for-loop implementation for batched inputs.

1. Motivation and Context

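We don't need `vmap` for Arnoldi at all, because Arnoldi computes a low-rank approximation of the inverse Hessian. Applying that approximation to a batch of vectors is just regular matrix multiplication, so a batch can be processed in one call instead of one vector at a time.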
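A minimal sketch of why plain matrix multiplication suffices. It assumes the Arnoldi step has already produced the top-k eigenvalues and orthonormal eigenvectors of the Hessian, so the inverse is approximated as `V diag(1/λ) V^T`. Here `torch.linalg.eigh` on a small synthetic matrix stands in for the Arnoldi iteration, and `low_rank_ihvp` / `low_rank_ihvp_loop` are hypothetical names for illustration, not dattri's API.

```python
import torch


def low_rank_ihvp(eigvals, eigvecs, vectors):
    """Batched approximation of H^{-1} v (hypothetical helper, not dattri's API).

    eigvals: (k,) top-k eigenvalues of H (assumed positive)
    eigvecs: (p, k) matching orthonormal eigenvectors
    vectors: (n, p) batch of n vectors, one per row
    Returns (n, p), where row i is V diag(1/eigvals) V^T vectors[i].
    """
    coeffs = vectors @ eigvecs  # project all rows at once: (n, p) @ (p, k) -> (n, k)
    coeffs = coeffs / eigvals  # scale by inverse eigenvalues, broadcast over rows
    return coeffs @ eigvecs.T  # back to parameter space: (n, k) @ (k, p) -> (n, p)


def low_rank_ihvp_loop(eigvals, eigvecs, vectors):
    # Per-vector reference: the same computation, one row at a time.
    return torch.stack([eigvecs @ ((eigvecs.T @ v) / eigvals) for v in vectors])


torch.manual_seed(0)
p, k, n = 6, 3, 4
A = torch.randn(p, p, dtype=torch.float64)
H = A @ A.T + p * torch.eye(p, dtype=torch.float64)  # synthetic symmetric positive-definite "Hessian"
evals, evecs = torch.linalg.eigh(H)  # eigenvalues in ascending order
top_vals, top_vecs = evals[-k:], evecs[:, -k:]  # keep the k largest eigenpairs
G = torch.randn(n, p, dtype=torch.float64)  # batch of n gradient-like vectors

batched = low_rank_ihvp(top_vals, top_vecs, G)
looped = low_rank_ihvp_loop(top_vals, top_vecs, G)
print(torch.allclose(batched, looped))

# With all p eigenpairs kept, the approximation recovers the exact inverse.
full = low_rank_ihvp(evals, evecs, G)
exact = torch.linalg.solve(H, G.T).T
print(torch.allclose(full, exact))
```

The batched and looped versions compute the same quantity. The batched one simply lets `@` act on all rows together, which is why no `vmap` is needed once the eigenpairs are fixed.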

2. Summary of the change

3. What tests have been added/updated for the change?