Closed SeanZh30 closed 4 months ago
For `train_lds`:

- I suggest renaming it to `retrain_subset` and implementing it in `dattri/model_utils/retrain.py`.
- `ratio=0.5` could serve as the default subset ratio.
- You have `model, dataloader, loss` in your parameters, but this means you need to assume the user's training loop is a classical one and set up the optimizer and other things yourself. I think using a `train_func` provided by the user could work. You may ask the user to supply a unified train function (e.g., one that takes a dataloader and returns a trained model).
- `:return: A tuple containing a list of trained models` could be very memory-consuming. I think it would be better to save the state dict and the training sample indices to disk once each model is trained (be sure to name the files in a consistent format so you can revisit them and pair them up later). This would require adding a `save_path` parameter.
- Rename `model_count` to `retrain_num`.
- Rename `subset_model_count` to `expectation_sample`.
For `cal_lds_gt`:

- It should live in `dattri/metrics/lds.py`, and I think `calculate_lds_groundtruth` would be clearer for future users. This ground truth has nothing to do with the attribution score.
- `models, test_loader, target_func` are the parameters you do need for the ground truth; I think you don't need `data_indices` or `train_loader`.
- Use `model` and `model_path_list`: the first one is an `nn.Module`, and the second is the list of state dict paths you saved in `retrain_subset`.
- `target_func` should not be "a data attribution method". It should take `test_loader` and `model` and return a tensor with all the target outputs.
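A sketch of the shape this function could take, under the same stand-ins as above (`pickle` in place of `torch.load`/`load_state_dict`, plain lists in place of tensors and DataLoaders):

```python
import pickle

def calculate_lds_groundtruth(model, model_path_list, test_loader, target_func):
    """Collect the target outputs of every retrained model.

    model:            the model object whose weights get reloaded for each
                      checkpoint (an nn.Module in dattri; unused in this
                      pickle-based sketch).
    model_path_list:  paths to the models saved by retrain_subset.
    test_loader:      test samples (a plain list here).
    target_func:      takes (test_loader, model) and returns the target
                      outputs, one per test sample.
    """
    groundtruth = []
    for path in model_path_list:
        with open(path, "rb") as f:
            retrained = pickle.load(f)
        # dattri would instead do: model.load_state_dict(torch.load(path))
        groundtruth.append(target_func(test_loader, retrained))
    return groundtruth  # shape: (num_models, num_test_samples)
```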
For `lds`:

- Put it in `dattri/metrics/lds.py`; I think it would be suitable to call it `linear_datamodeling_score`.
- You do need `data_indices` for this function, because you need it to calculate the "counterfactual predictions".
- Add an `aggregation` parameter, so that users can choose `raw` (return a list) or `mean` (return a value).

Please revise the issue again and @-mention me for review before you implement anything.
# API Design Proposal for LDS Function Implementation

## Context/Background

This proposal aims to design and implement an API for the Linear Datamodeling Score (LDS) method as described in the paper "TRAK: Attributing Model Behavior at Scale" (Section 2.1). The LDS method involves training multiple models on subsets of the data to evaluate counterfactual predictions, which estimate how model outputs change when the training set is modified.
## Algorithm Description

The LDS method consists of the following steps:

1. Sample random subsets of the training set (e.g., with `ratio=0.5`).
2. Retrain a model on each subset and record the subset's sample indices.
3. Compute each retrained model's target outputs on the test set (the ground truth).
4. Compute the attribution-based counterfactual prediction for each subset by summing the attribution scores of the training samples it contains.
5. Compute the Spearman rank correlation between the predictions and the ground-truth outputs.
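In the TRAK paper's notation (Section 2.1), for a test example $z$, sampled subsets $S_1, \dots, S_m$, a model $f(z;\theta^\star(S))$ trained on $S$, and attribution scores $\tau(z)$, the counterfactual prediction and the LDS are:

```latex
g_\tau(z; S) = \sum_{i \in S} \tau(z)_i,
\qquad
\mathrm{LDS}(\tau, z) = \rho_{\mathrm{spearman}}\!\left(
  \{ g_\tau(z; S_j) \}_{j=1}^{m},\;
  \{ f(z; \theta^{\star}(S_j)) \}_{j=1}^{m}
\right)
```

A `mean` aggregation would then average $\mathrm{LDS}(\tau, z)$ over the test examples.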
## API Design

## Demonstration

## Actionable Items (To-do List)
- Implement the `retrain_subset` function in `dattri/model_utils/retrain.py`.
- Implement the `calculate_lds_groundtruth` function in `dattri/metrics/lds.py`.
- Implement the `linear_datamodeling_score` function in `dattri/metrics/lds.py`.
- Save the retrained models' `.pt` files locally for future publication.