TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.func.random_projection] Format/naming/bugs fixing in related files #63

Closed: tingwl0122 closed this pull request 3 months ago

tingwl0122 commented 3 months ago

Description

This PR fixes minor issues in the existing dattri/func/random_projection.py and its related test files.

1. Motivation and Context

The original dattri/func/random_projection.py closely mirrors the TRAK implementation, which is written around the specific task of projecting a gradient matrix. Although that is the typical usage, the function is a general-purpose projection, so we remove gradient- and model-specific names such as grad_dim and model_id.
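To illustrate what a projection with no gradient- or model-specific assumptions looks like, here is a minimal pure-Python sketch of a Gaussian random projection. The function name random_project and its parameters (features, proj_dim, seed) are hypothetical and are not dattri's actual API. The real dattri/func/random_projection.py operates on PyTorch tensors and may use a different projection scheme. The input dimension is called feature_dim here only to show that nothing ties it to gradients.

```python
import math
import random
from typing import List, Optional


def random_project(
    features: List[List[float]],
    proj_dim: int,
    seed: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical sketch: map each row of `features` (shape [n, feature_dim])
    to `proj_dim` dimensions using a Gaussian random matrix.

    Nothing here assumes the rows are gradients or come from a particular model.
    """
    if not features:
        return []
    feature_dim = len(features[0])
    rng = random.Random(seed)  # seeded so the same projection can be reproduced

    # Projection matrix of shape [feature_dim, proj_dim] with entries drawn from
    # N(0, 1/proj_dim), so squared norms are preserved in expectation.
    scale = 1.0 / math.sqrt(proj_dim)
    matrix = [
        [rng.gauss(0.0, 1.0) * scale for _ in range(proj_dim)]
        for _ in range(feature_dim)
    ]

    projected = []
    for row in features:
        if len(row) != feature_dim:
            raise ValueError("all rows must have the same feature_dim")
        # Row-vector times matrix: out[j] = sum_i row[i] * matrix[i][j]
        projected.append(
            [sum(row[i] * matrix[i][j] for i in range(feature_dim)) for j in range(proj_dim)]
        )
    return projected
```

Because the same seed regenerates the same matrix, separately projected batches remain comparable, which is what attribution methods rely on when comparing projected gradients across training and test examples.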

In addition, this PR moves some helper functions out to func/utils.py for clarity and fixes some coding style issues.

2. Summary of the change

3. What tests have been added/updated for the change?

tingwl0122 commented 3 months ago

Hi @jiaqima. One small point for discussion: in https://github.com/TRAIS-Lab/dattri/pull/20#pullrequestreview-2018413915, we said we might want to replace the _vectorize function in utils.py with flatten_params.

However, (1) the current flatten_params depends on _vectorize, and (2) my usage of _vectorize is different from what flatten_params does.

So I don't think we need to deprecate _vectorize. Does that sound right?

tingwl0122 commented 3 months ago

@TheaperDeng, I think this can also be merged if you are comfortable with the notation change.