TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

Implemented retrain_loo and calculate_loo_groundtruth. #16

charles-pyj closed 5 months ago

charles-pyj commented 5 months ago

1. Motivation and Context

retrain_loo

We would like to provide users with a leave-one-out (LOO) function that retrains models and saves them in a consistent format for later use. The training function is specified by the user; our function is in charge of splitting the data into leave-one-out subsets, calling the user-specified training function to retrain a model on each subset, and then saving the models.

calculate_loo_groundtruth

After the user has called retrain_loo and thereby generated a directory of retrained LOO models, calculate_loo_groundtruth calculates the ground-truth values. The target function for the ground truth is user-specified, e.g. accuracy for binary classification or cross-entropy loss for softmax regression.

2. Summary of change

retrain_loo

We first take the input dataloader and generate a list of lists, each of which represents a LOO subset of the full indices. We then subset the dataloader and call the user-specified training function on each subset. We also keep track of the excluded indices, saved directories, etc. in the metadata, which we save at the root directory that the user provides. We return None, as the function is only in charge of saving the models.
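A rough sketch of the flow described above, not the code in this PR. To stay self-contained it assumes a plain Python list instead of a dataloader and `pickle` instead of a PyTorch checkpoint. The names `retrain_loo_sketch`, `loo_index_subsets`, `mean_model`, `index_{i}`, `model.pkl`, and `metadata.json` are all hypothetical and chosen only for illustration.

```python
import json
import pickle
import tempfile
from pathlib import Path


def loo_index_subsets(n):
    """Return n lists; the i-th list holds every index in range(n) except i."""
    return [[j for j in range(n) if j != i] for i in range(n)]


def retrain_loo_sketch(train_func, data, path):
    """Retrain once per left-out sample, save each model, then save metadata.

    Hypothetical stand-in for the behaviour described in the PR text.
    """
    root = Path(path)
    root.mkdir(parents=True, exist_ok=True)
    metadata = {"excluded_indices": [], "model_dirs": []}

    for excluded, kept in enumerate(loo_index_subsets(len(data))):
        # Train on every sample except the excluded one.
        model = train_func([data[j] for j in kept])

        # One sub-directory per retrained model (layout is an assumption).
        model_dir = root / f"index_{excluded}"
        model_dir.mkdir(exist_ok=True)
        with open(model_dir / "model.pkl", "wb") as f:
            pickle.dump(model, f)

        metadata["excluded_indices"].append(excluded)
        metadata["model_dirs"].append(str(model_dir))

    # Metadata lives at the root directory supplied by the caller.
    with open(root / "metadata.json", "w") as f:
        json.dump(metadata, f)
    # No return statement: the function only saves models to disk.


def mean_model(samples):
    """Toy 'training function': the model is just the sample mean."""
    return {"mean": sum(samples) / len(samples)}


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp()
    retrain_loo_sketch(mean_model, [1.0, 2.0, 3.0], out_dir)
    print(sorted(p.name for p in Path(out_dir).iterdir()))
```

With a real dataloader, the list comprehension that selects the kept samples would presumably be replaced by a subset of the underlying dataset; that detail is left out here.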

calculate_loo_groundtruth

Given the directory structure generated by retrain_loo, we first read in all the models and then call the user-specified target function on each of them to generate the ground-truth values. We return the result as well as a list of excluded indices.
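A minimal sketch of this read-and-evaluate step, again not the code in this PR. It assumes a hypothetical layout (a `metadata.json` at the root listing `model_dirs` and `excluded_indices`, with one pickled `model.pkl` per directory) and builds a toy directory itself so the example runs on its own. All function names here are made up for illustration.

```python
import json
import pickle
import tempfile
from pathlib import Path


def calculate_loo_groundtruth_sketch(target_func, path, test_data):
    """Load every saved LOO model and evaluate target_func on it.

    Returns (values, excluded_indices), where values[k] is the target value
    of the model retrained without excluded_indices[k].
    """
    root = Path(path)
    with open(root / "metadata.json") as f:
        metadata = json.load(f)

    values = []
    for model_dir in metadata["model_dirs"]:
        with open(Path(model_dir) / "model.pkl", "rb") as f:
            model = pickle.load(f)
        values.append(target_func(model, test_data))
    return values, metadata["excluded_indices"]


def write_toy_loo_directory(path, train_data):
    """Build a tiny LOO directory in the assumed layout (mean-only 'models')."""
    root = Path(path)
    metadata = {"excluded_indices": [], "model_dirs": []}
    for i in range(len(train_data)):
        kept = [x for j, x in enumerate(train_data) if j != i]
        model_dir = root / f"index_{i}"
        model_dir.mkdir(parents=True, exist_ok=True)
        with open(model_dir / "model.pkl", "wb") as f:
            pickle.dump({"mean": sum(kept) / len(kept)}, f)
        metadata["excluded_indices"].append(i)
        metadata["model_dirs"].append(str(model_dir))
    with open(root / "metadata.json", "w") as f:
        json.dump(metadata, f)


def squared_error(model, test_data):
    """Toy target function: mean squared error of the constant prediction."""
    return sum((x - model["mean"]) ** 2 for x in test_data) / len(test_data)


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp()
    write_toy_loo_directory(out_dir, [1.0, 2.0, 3.0])
    values, excluded = calculate_loo_groundtruth_sketch(
        squared_error, out_dir, [2.0]
    )
    print(values, excluded)
```

Keeping the excluded indices next to the values lets a caller line up each ground-truth value with the training sample that was left out.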

3. What tests have been added/updated for the change?

We will add tests that use logistic regression and a CNN on the MNIST dataset. We will also deliver the retrained models.

xingjian-zhang commented 5 months ago

Hi @charles-pyj! Thanks for the effort. Please let me know (@ me) when this PR is ready for review.

charles-pyj commented 5 months ago

Hi @xingjian-zhang! We believe that, apart from some formatting issues, our PR is ready for review!

charles-pyj commented 5 months ago

Hi @xingjian-zhang, thanks for the review! I have updated the code for more clarity. That said, I still cannot pass Darglint due to DAR202. This is because retrain is supposed to return None, but the keyword None is not recognized by Darglint. If I put nothing after "Returns:", I cannot pass ruff. Is there any way to get around these checks?

xingjian-zhang commented 5 months ago

> Hi @xingjian-zhang, thanks for the review! I have updated the code for more clarity. That said, I still cannot pass Darglint due to DAR202. This is because retrain is supposed to return None, but the keyword None is not recognized by Darglint. If I put nothing after "Returns:", I cannot pass ruff. Is there any way to get around these checks?

I guess you could try removing the Returns section from the docstring, since the function does not return anything?
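For illustration, a docstring in that shape might look like the following. The function `append_record` is a made-up example, not code from this PR, and whether it satisfies both linters depends on the project's configuration.

```python
def append_record(record, store):
    """Append a record to a store in place.

    Args:
        record: The item to add.
        store: A list that is modified in place.
    """
    # No "Returns:" section above, since nothing is returned.
    store.append(record)
```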

xingjian-zhang commented 5 months ago

LGTM. If there are no other changes, we can merge it.