TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

fix some missing parts in mnist_lr retrain func #44

Closed: TheaperDeng closed this 4 months ago

TheaperDeng commented 4 months ago

Description

1. Motivation and Context

This PR fixes some minor issues in the MNIST + LR (logistic regression) retraining and loss-calculation functions.

2. Summary of the change

  1. Increase the number of training epochs to 20, which brings test accuracy to 90% (the previous setting of 1 epoch seems too small).
  2. Add a `device` parameter to the functions so that users can choose which device runs the retraining/evaluation.
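The two changes above (a 20-epoch default and a `device` parameter threaded through training and evaluation) can be illustrated with a minimal sketch. This is not the dattri implementation: the names `train_lr_sketch` and `eval_loss_sketch`, the single `nn.Linear` model, the SGD optimizer, and the learning rate of 0.01 are all hypothetical choices for illustration. Only the epoch count and the idea of a `device` argument come from this PR.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_lr_sketch(
    train_loader: DataLoader,
    epoch_num: int = 20,  # default taken from this PR's change (was 1)
    device: str = "cpu",  # hypothetical parameter name; caller picks "cpu" or "cuda"
) -> nn.Module:
    """Hypothetical helper: train a multinomial logistic regression on 28x28 inputs."""
    # Assumed model: one linear layer over flattened pixels, 10 classes.
    model = nn.Linear(28 * 28, 10).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumed optimizer/lr
    loss_fn = nn.CrossEntropyLoss()  # softmax + negative log-likelihood
    model.train()
    for _ in range(epoch_num):
        for inputs, labels in train_loader:
            # Move every batch to the same device as the model.
            inputs = inputs.view(inputs.size(0), -1).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model


@torch.no_grad()
def eval_loss_sketch(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Hypothetical helper: mean cross-entropy loss of `model` over `loader`."""
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    total, count = 0.0, 0
    for inputs, labels in loader:
        inputs = inputs.view(inputs.size(0), -1).to(device)
        labels = labels.to(device)
        total += loss_fn(model(inputs), labels).item()
        count += labels.size(0)
    return total / count
```

A caller would typically pick the device once, for example `device = "cuda" if torch.cuda.is_available() else "cpu"`, and pass the same value to both functions so that the model and every batch live on the same device.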

3. What tests have been added/updated for the change?