Closed KurisuTheAmadeus closed 4 months ago
For `train_loo`:

Rename it `retrain_loo` and implement it in `dattri/model_utils/retrain.py`.

You have `full_model, train_dataloader, loss_func` in your parameters, but this means you need to assume the user's training loop is a classical one and set the optimizer and other things yourself. I think using a `train_func` provided by the user could work: you may ask the user for a unified training function (e.g., one that takes a dataloader and returns a trained model).

`:return: A tuple containing a list of trained models`. This could be very memory-consuming. It would be better to save the state dict and the training sample index to disk once a model is trained. Be sure to name the files in a consistent format so you can revisit them and pair them up later (e.g., `model_remove_index_0.pt`, `model_remove_index_1.pt`). This would require you to add a `save_path` parameter.

`require_seed` and `seed`: just keep one of them.

Add `truncate: int = None`, meaning that LOO models are trained only for the first `truncate` training samples.

For `cal_loo_gt`, `cal_pbrf_gt`:

Put them in `dattri/metrics/loo.py`. I also think `calculate_loo_groundtruth` and `calculate_pbrf_groundtruth` might be clearer names for future users; this ground truth has nothing to do with the attribution score.

Rename `full_model` -> `model` and `loss_func` -> `target_func`. Please talk with @SeanZh30 and @jackhaohuang about this function's design, because it is very much like `calculate_lds_groundtruth`.
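To make the suggestion above concrete, here is a minimal sketch of what the retraining driver could look like. All names (`retrain_loo`, `save_path`, `truncate`) follow the comments above, but the body is only an assumed implementation, not the actual dattri code; it assumes the user's `train_func` takes a dataloader and returns a trained model.

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Subset


def retrain_loo(train_func, dataloader, save_path, truncate=None, seed=None):
    """Retrain once per left-out training sample, saving each state dict.

    `train_func` is user-supplied: it takes a dataloader and returns a
    trained model, so no assumption is made about the training loop.
    """
    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)
    dataset = dataloader.dataset
    n_models = len(dataset) if truncate is None else min(truncate, len(dataset))
    for index in range(n_models):
        if seed is not None:
            torch.manual_seed(seed)  # a single `seed` arg, no `require_seed`
        # Dataloader over the training set with sample `index` removed.
        keep = [i for i in range(len(dataset)) if i != index]
        loo_loader = DataLoader(Subset(dataset, keep),
                                batch_size=dataloader.batch_size)
        model = train_func(loo_loader)
        # Save to disk instead of returning all models in memory; the file
        # name encodes the removed index so pairs can be recovered later.
        torch.save(model.state_dict(),
                   save_dir / f"model_remove_index_{index}.pt")
```

Saving to disk keeps memory flat regardless of dataset size, and the index embedded in the file name is what lets `calculate_loo_groundtruth` pair each saved model with its removed training sample.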
Just for clarification: regarding `train_loo`, we are expecting a training function provided by the user, and the actual training is done in that function; what we do in `train_loo` is just call the `train_func` and save the `state_dict`, right?
> Just for clarification: regarding `train_loo`, we are expecting a training function provided by the user, and the actual training is done in that function; what we do in `train_loo` is just call the `train_func` and save the `state_dict`, right?
Yes, so that we don't need to "assume" what the training loop looks like. Different users may want different training loops; we just ask them to provide a function whose signature (input and output) is defined by us.
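For instance, a user with a classical training loop could wrap it in the agreed signature like this (the model, optimizer, and hyperparameters below are just placeholders); the retraining driver then only needs to call this function and save the returned model's `state_dict`:

```python
import torch


def my_train_func(dataloader):
    """User-defined training function with the unified signature:
    it takes a dataloader and returns a trained model."""
    model = torch.nn.Linear(10, 1)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_func = torch.nn.MSELoss()
    for _ in range(5):  # placeholder number of epochs
        for x, y in dataloader:
            optimizer.zero_grad()
            loss_func(model(x), y).backward()
            optimizer.step()
    return model
```

The library never sees the optimizer or the loss; it only sees `dataloader in, trained model out`.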
Thanks! Also, how should we verify the PBRF function we wrote? The original paper does provide an example, but it is done on a different model than ours.
Context
The goal of this part of the API is to support LOO-retraining and PBRF-retraining (https://arxiv.org/pdf/2209.05364.pdf). We will support:

a. Output the models of LOO-retraining.
b. Directly output the LOO score for each test example on a subset of training points (it can be all training points).
c. Output the models of PBRF-retraining.
d. Directly output the PBRF score for each test example on a subset of training points (it can be all training points).
Algorithm design
API Design
Usage Examples
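As a sketch only, a hypothetical end-to-end call sequence combining the points discussed above (`train_func`, `save_path`, `truncate`, and the renamed ground-truth function) might look like this; none of these names, signatures, or module paths are final:

```python
# Hypothetical usage sketch; names and signatures are still under discussion.
from dattri.model_utils.retrain import retrain_loo
from dattri.metrics.loo import calculate_loo_groundtruth

# a. Retrain one model per left-out training sample, saved under save_path.
retrain_loo(train_func=my_train_func,
            dataloader=train_loader,
            save_path="./loo_models",
            truncate=100,  # only the first 100 training samples
            seed=0)

# b. LOO ground-truth score for each test example, read back from disk.
scores = calculate_loo_groundtruth(model=full_model,
                                   target_func=loss_func,
                                   save_path="./loo_models",
                                   test_dataloader=test_loader)
```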
To-do List
- `train_loo` and `cal_loo_gt`.
- `train_pbrf` and `cal_pbrf_gt`.