TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

API Design Proposal for LOO and PBRF #11

Closed KurisuTheAmadeus closed 4 months ago

KurisuTheAmadeus commented 5 months ago

Context

The goal of this part of the API is to support LOO retraining and PBRF retraining (https://arxiv.org/pdf/2209.05364.pdf). We will support:

  a. Outputting the models of LOO retraining.
  b. Directly outputting the LOO score for each test example on a subset of training points (which can be all training points).
  c. Outputting the models of PBRF retraining.
  d. Directly outputting the PBRF score for each test example on a subset of training points (which can be all training points).

Algorithm design

  1. For a, the algorithm for outputting the LOO-retrained models is straightforward: we train the LOO models sequentially, one for each left-out training point.
  2. For c, as the PBRF paper suggests, we will require a 'pretrained model' ($\theta_s$ in the PBRF paper) and then train for $K$ steps with respect to the PBRF objective function.
  3. For both b and d, we loop over all training points: train one model, evaluate all the test points on it, and store the result. Note that we do not train all models first and compute the scores afterwards: that is too memory intensive and can be avoided.
  4. We emphasize that the seed configuration is important. Without a fixed seed, the output will be random and cannot be easily reproduced.
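The streaming loop in step 3 could look roughly like the sketch below. It is framework-agnostic plain Python rather than the proposed dattri API: `loo_groundtruth_sketch`, `train_func`, `target_func`, and the toy mean "model" are hypothetical names chosen for illustration, and the score stored here is simply the target value on each test point after one training point is removed (an assumption, since the proposal does not pin down the exact score definition).

```python
import random
from typing import Callable, List, Optional, Sequence


def loo_groundtruth_sketch(
    train_func: Callable[[Sequence], object],
    target_func: Callable[[object, object], float],
    train_data: Sequence,
    test_data: Sequence,
    seed: Optional[int] = None,
) -> List[List[float]]:
    """Return scores[i][j]: target on test point j after removing train point i."""
    scores = []
    for i in range(len(train_data)):
        if seed is not None:
            # Reset the RNG before every retraining run so runs are reproducible.
            random.seed(seed)
        subset = [z for k, z in enumerate(train_data) if k != i]
        model = train_func(subset)  # only one LOO model is alive at a time
        scores.append([target_func(model, z) for z in test_data])
        del model  # drop the model before training the next one
    return scores


# Toy stand-ins (hypothetical): the "model" is the mean of the training values.
def train_mean(data):
    return sum(data) / len(data)


def squared_error(model, z):
    return (model - z) ** 2


scores = loo_groundtruth_sketch(
    train_mean, squared_error, [1.0, 2.0, 3.0], [2.0], seed=0
)
print(scores)
```

The result has one row per training point and one column per test point, matching the (number of training points, number of test points) shape described in the proposal. In a PyTorch setting the seeding line would presumably also call `torch.manual_seed(seed)`.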

API Design

def train_loo(full_model, loss_func, train_dataloader, require_seed=False, seed=None):
    '''
    Function for training a model using the Leave-One-Out (LOO) cross-validation method. This approach involves training the model multiple times, each time using all data points except one.

    :param full_model: A PyTorch nn.Module representing the complete model architecture to be trained. This model will be trained multiple times, each with a different subset of the training data.
    :param loss_func: A PyTorch loss function used to compute the error between the model predictions and the actual targets during each training iteration.
    :param train_dataloader: A PyTorch DataLoader providing the dataset for training. It should yield batches of data, where each batch contains input features and targets.
    :param require_seed: (optional) A boolean indicating whether to use a fixed seed for the random number generator to ensure reproducibility across the training iterations. Defaults to False.
    :param seed: (optional) An integer to set as the seed for the random number generator if require_seed is True. This parameter is ignored if require_seed is False.

    :return: A list of models trained using the LOO method. Each model in the list is trained on the dataset excluding one different data point, providing a comprehensive set of models to estimate the generalization error.
    '''

def cal_loo_gt(full_model, loss_func, train_dataloader, test_dataloader, require_seed=False, seed=None):
    '''
    Function for calculating the Leave-One-Out (LOO) influence of training points on test points.

    :param full_model: A PyTorch nn.Module representing the complete model architecture to be evaluated.
    :param loss_func: A PyTorch loss function used to compute the error between the model predictions and the actual targets during the LOO training and evaluation.
    :param train_dataloader: A PyTorch DataLoader providing the dataset for training, comprising batches of input features and targets. This dataset is used for the LOO training process.
    :param test_dataloader: A PyTorch DataLoader providing the dataset for testing, comprising batches of input features and targets. This is used to evaluate the performance of the model trained on each LOO subset.
    :param require_seed: (optional) A boolean indicating whether to use a fixed seed for the random number generator, ensuring reproducibility across the LOO training iterations. Defaults to False.
    :param seed: (optional) An integer to set as the seed for the random number generator if require_seed is True. This parameter is ignored if require_seed is False.

    :return: A tensor of dimensions (number of training points, number of test points), representing the LOO influence of training points on test points.
    '''

def train_pbrf(pretrained_model, loss_func, k, train_dataloader, require_seed=False, seed=None):
    '''
    Function for training models with respect to PBRF objective.
    :param pretrained_model: A PyTorch nn.Module representing the pretrained model architecture to be further trained.
    :param loss_func: A PyTorch loss function used to compute the error between the model predictions and the actual targets.
    :param k: The number of training steps to perform with respect to the PBRF objective, dictating the extent of training under this method.
    :param train_dataloader: A PyTorch DataLoader providing the dataset for training, comprising batches of input features and targets.
    :param require_seed: (optional) A boolean indicating whether to use a fixed seed for the random number generator, ensuring reproducibility. Defaults to False.
    :param seed: (optional) An integer to set as the seed for the random number generator if require_seed is True. This parameter is ignored if require_seed is False.

    :return: A list containing the trained model(s) following the PBRF training steps.
    '''

def cal_pbrf_gt(pretrained_model, full_model, loss_func, train_dataloader, test_dataloader, K, require_seed=False, seed=None):
    '''
    Function for calculating the influence values using the PBRF method.
    This involves applying the PBRF training steps to a pretrained model, evaluating the resulting model on a test dataset, and comparing the loss difference.

    :param pretrained_model: A PyTorch nn.Module representing the pretrained model architecture to be further trained.
    :param full_model: A PyTorch nn.Module representing the full model.
    :param loss_func: A PyTorch loss function used to compute the error between the model predictions and the actual targets during training and testing.
    :param train_dataloader: A PyTorch DataLoader providing the dataset for training, comprising batches of input features and targets.
    :param test_dataloader: A PyTorch DataLoader providing the dataset for testing, comprising batches of input features and targets. This is used to evaluate the model's performance after applying the PBRF method.
    :param K: The number of training steps to perform with respect to the PBRF objective during the training process.
    :param require_seed: (optional) A boolean indicating whether to use a fixed seed for the random number generator to ensure reproducibility. Defaults to False.
    :param seed: (optional) An integer to set as the seed for the random number generator if require_seed is True. This parameter is ignored if require_seed is False.

    :return: A tensor of dimensions (number of training points, number of test points), representing the scores calculated for each combination of training and test data points under the PBRF method.
    '''

Usage Examples

from torch import nn
from torch.utils.data import DataLoader

# LogisticRegression, train, and pre_train are assumed to be user-defined helpers.
model = LogisticRegression()
full_model = train(model)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(train_data_subset, batch_size=64)
test_loader = DataLoader(test_data_subset, batch_size=64)
loo_models = train_loo(model, criterion, train_loader, require_seed=True, seed=0)

pretrained_model = pre_train(model)
pbrf_score = cal_pbrf_gt(pretrained_model, full_model, criterion, train_loader, test_loader, 10, require_seed=False, seed=None)

To-do List

TheaperDeng commented 5 months ago

For train_loo:

  1. We may need this function for other use cases as well (e.g., for TRAK model retraining). I suggest renaming it to retrain_loo and implementing it in dattri/model_utils/retrain.py.
  2. It's okay to take full_model, train_dataloader, and loss_func as parameters, but this means you have to assume the user's training loop is a classical one and set up the optimizer and other things yourself. I think using a train_func provided by the user could work. You may ask users to provide a unified train function (e.g., one that takes a dataloader and returns a trained model).
  3. ":return: A tuple containing a list of trained models." This could be very memory consuming. I think it would be better to save the state dict and the training sample index to disk once each model is trained. Be sure to name the files in a consistent format so you can revisit them and pair them up (e.g., model_remove_index_0.pt, model_remove_index_1). This would require adding a save_path parameter.
  4. No need to have both require_seed and seed. Just keep one of them.
  5. Add a parameter truncate: int = None. It means that LOO models are trained only for the first truncate data samples.

TheaperDeng commented 5 months ago

For cal_loo_gt, cal_pbrf_gt:

  1. This method computes the LOO ground truth. Please implement it in dattri/metrics/loo.py. I think calculate_loo_groundtruth and calculate_pbrf_groundtruth might be clearer names for future users. This ground truth has nothing to do with the attribution score.
  2. full_model -> model, loss_func -> target_func. Please talk with @SeanZh30 and @jackhaohuang about this function's design, because it's very similar to calculate_lds_groundtruth.

charles-pyj commented 5 months ago

Just for clarification, regarding train_loo: we are expecting a training function provided by the user, the actual training is done in that function, and all we do in train_loo is call train_func and save the state_dict, right?

TheaperDeng commented 5 months ago

Just for clarification, regarding train_loo: we are expecting a training function provided by the user, the actual training is done in that function, and all we do in train_loo is call train_func and save the state_dict, right?

Yes, so that we don't need to "assume" what the training loop looks like. Different users may want different training loops. We just ask them to provide a function that has the signature (input and output) defined by us.
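A minimal sketch of the design discussed in this thread (a user-supplied train_func, a single seed parameter, truncate, and one saved file per removed index under save_path) might look like the following. It is an illustration under stated assumptions, not the dattri implementation: `retrain_loo_sketch` and `train_mean` are hypothetical, the data is a plain Python list rather than a DataLoader, and `pickle` stands in for saving a PyTorch state dict.

```python
import pickle
import random
import tempfile
from pathlib import Path
from typing import Callable, List, Optional, Sequence


def retrain_loo_sketch(
    train_func: Callable[[Sequence], object],
    train_data: Sequence,
    save_path: str,
    seed: Optional[int] = None,
    truncate: Optional[int] = None,
) -> List[Path]:
    """Retrain once per removed index and write each result to disk."""
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # truncate limits retraining to the first `truncate` training samples.
    n = len(train_data) if truncate is None else min(truncate, len(train_data))
    saved = []
    for i in range(n):
        if seed is not None:
            random.seed(seed)  # same seed before every run
        subset = [z for k, z in enumerate(train_data) if k != i]
        model = train_func(subset)  # the user's function owns the training loop
        path = out_dir / f"model_remove_index_{i}.pt"
        with open(path, "wb") as f:
            # Store the removed index next to the model so the two can be paired later.
            pickle.dump({"removed_index": i, "model": model}, f)
        saved.append(path)
    return saved


# Toy train function (hypothetical): the "model" is the mean of the data.
def train_mean(data):
    return sum(data) / len(data)


with tempfile.TemporaryDirectory() as tmp:
    paths = retrain_loo_sketch(
        train_mean, [1.0, 2.0, 3.0, 4.0], tmp, seed=0, truncate=2
    )
    names = [p.name for p in paths]
    with open(paths[0], "rb") as f:
        first = pickle.load(f)

print(names, first)
```

With a real PyTorch model, `torch.save(model.state_dict(), path)` would presumably replace the `pickle.dump` call, and the subset would be built from the DataLoader's dataset instead of a list.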

charles-pyj commented 5 months ago

Thanks! We would also like to know how we should verify the PBRF function we wrote. The original paper does provide an example, but it uses a different model from ours.