TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

Model training functions and loss calculation functions for mnist + lr #25

Closed: TheaperDeng closed this pull request 6 months ago.

TheaperDeng commented 6 months ago

Description

1. Motivation and Context

This PR provides a standard model training function and a standard loss calculation function for MNIST with logistic regression (LR).

2. Summary of the change

Adds two new functions (train_mnist_lr and loss_mnist_lr) to dattri.datasets.mnist. They can be used by the retraining utils and the ground-truth utils.
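The PR text does not show the signatures of train_mnist_lr and loss_mnist_lr. The code below is a minimal sketch of how a train function and a loss function might be paired for retraining and ground-truth computation. It assumes "lr" means logistic regression and that leave-one-out retraining is the kind of ground truth being computed. The names train_lr_sketch, loss_lr_sketch, and leave_one_out_deltas are hypothetical. The sketch is pure Python, with toy two-feature data standing in for MNIST tensors and a PyTorch model, so it is not dattri's API.

```python
import math
import random

# A minimal sketch, NOT dattri's API: hypothetical pure-Python stand-ins for a
# "train" function and a "loss" function for multinomial logistic regression.
# A model is a (weights, bias) pair; data is a list of (features, label) pairs,
# standing in for MNIST tensors and a PyTorch model.


def _softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _logits(model, x):
    weights, bias = model
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def train_lr_sketch(data, num_features, num_classes, epochs=50, step_size=0.1, seed=0):
    """Train logistic regression with plain SGD on cross-entropy loss."""
    rng = random.Random(seed)  # fixed seed so retraining is reproducible
    weights = [[0.0] * num_features for _ in range(num_classes)]
    bias = [0.0] * num_classes
    order = list(range(len(data)))
    for _ in range(epochs):
        rng.shuffle(order)
        for idx in order:
            x, y = data[idx]
            probs = _softmax(_logits((weights, bias), x))
            for k in range(num_classes):
                # d(cross-entropy)/d(logit_k) = p_k - 1[k == y]
                grad = probs[k] - (1.0 if k == y else 0.0)
                for j in range(num_features):
                    weights[k][j] -= step_size * grad * x[j]
                bias[k] -= step_size * grad
    return weights, bias


def loss_lr_sketch(model, data):
    """Mean cross-entropy loss of `model` over `data`."""
    return sum(-math.log(_softmax(_logits(model, x))[y]) for x, y in data) / len(data)


def leave_one_out_deltas(train_data, test_data, num_features, num_classes):
    """Hypothetical ground-truth helper: retrain once per removed training
    point and record the change in test loss relative to the full model."""
    full_model = train_lr_sketch(train_data, num_features, num_classes)
    base_loss = loss_lr_sketch(full_model, test_data)
    deltas = []
    for i in range(len(train_data)):
        subset = train_data[:i] + train_data[i + 1:]
        model = train_lr_sketch(subset, num_features, num_classes)
        deltas.append(loss_lr_sketch(model, test_data) - base_loss)
    return deltas


# Toy two-feature, two-class data in place of MNIST (784 features, 10 classes).
train_data = [([1.0, 0.0], 0), ([0.9, 0.1], 0), ([0.0, 1.0], 1), ([0.1, 0.9], 1)]
test_data = [([0.8, 0.2], 0), ([0.2, 0.8], 1)]

model = train_lr_sketch(train_data, num_features=2, num_classes=2)
print("test loss:", loss_lr_sketch(model, test_data))
print("leave-one-out deltas:", leave_one_out_deltas(train_data, test_data, 2, 2))
```

Keeping training and loss evaluation as two separate callables lets a retraining loop like leave_one_out_deltas stay agnostic to the dataset/model pair. Supporting another benchmark would then mean supplying a different pair of functions rather than editing the loop.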

3. What tests have been added/updated for the change?

TheaperDeng commented 6 months ago

I'm not sure if the name dattri/datasets would be a bit misleading given both models and datasets are implemented under it. Perhaps it makes more sense to call it dattri/benchmarks?

Yeah, benchmarks now seems like a better name than datasets.

jiaqima commented 6 months ago

Minor point: for consistency, I think we should use either dattri.metric, dattri.benchmark, dattri.algorithm, OR dattri.metrics, dattri.benchmarks, dattri.algorithms.

TheaperDeng commented 6 months ago

Let me open an issue for this and merge this PR first.