TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

API Design Proposal for LDS Function Implementation #9

Closed SeanZh30 closed 4 months ago

SeanZh30 commented 5 months ago

API Design Proposal for LDS Function Implementation

Context/Background

This proposal aims to design and implement an API for the Linear Datamodeling Score (LDS) method described in the paper "TRAK: Attributing Model Behavior at Scale" (Section 2.1). The LDS method involves training multiple models on subsets of the training data to evaluate counterfactual predictions, i.e., predictions of how model outputs change when the training set is modified.

Algorithm Description

The LDS method consists of the following steps:

  1. Select a 50% random subset of the training data.
  2. Train 5 models on this subset and compute the expected (averaged) outcome of these models.
  3. Repeat the above steps 100 times to generate 500 models and 100 expected outcomes.
  4. Compute the LDS for each test example as the Spearman rank correlation between the averaged model outputs from the previous step and the attribution-derived predictions of those outputs (see the sketch below).
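
To make step 4 concrete, here is a minimal sketch of the statistic for a single test example (all names and shapes are illustrative, not part of the proposed API):

import numpy as np
from scipy.stats import spearmanr

# Illustrative setup: 1000 training examples, 100 subsets of size 500.
avg_outputs = np.random.randn(100)   # averaged output of the 5 models per subset (placeholder)
tau = np.random.randn(1000)          # attribution score per training example (placeholder)
subset_idx = [np.random.choice(1000, 500, replace=False) for _ in range(100)]

# Attribution-derived prediction for subset j: sum of the scores of its members.
predictions = np.array([tau[idx].sum() for idx in subset_idx])

lds, p_value = spearmanr(predictions, avg_outputs)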

API Design

def retrain_subset(train_func, save_path, ratio=0.5, epochs=100, retrain_num=100, expectation_sample=5):
    """
    Retrains models on subsets of the data defined by a specified ratio, using a training function provided by the user. 
    Each trained model's state dict and the indices of the training samples are saved to the given path.

    :param train_func: A function provided by the user that takes a dataloader and returns a trained model.
    :param save_path: The path where trained models' state dicts and training sample indices are saved.
    :param ratio: The ratio of the training data to use for each subset.
    :param epochs: The number of epochs for training each model.
    :param retrain_num: The number of times the retraining process is repeated.
    :param expectation_sample: The number of models to train for each subset to compute the expected outcome.
    :return: None. (Trained models and their data indices are saved to the specified path.)
    """

def calculate_lds_groundtruth(model, model_path_list, test_loader, target_func):
    """
    Calculates the ground truth for the LDS method using a provided model architecture and a list of paths to model state dicts.

    :param model: A PyTorch nn.Module representing the model architecture.
    :param model_path_list: A list of paths to saved model state dicts.
    :param test_loader: A PyTorch DataLoader providing the dataset for evaluation.
    :param target_func: A function that takes a model and a DataLoader, returning a tensor of target outputs.
    :return: A list of tensors, each representing the ground truth calculated using the target function on the test dataset for each model.
    """

def linear_datamodeling_score(score, ground_truth, data_indices, aggregation='mean'):
    """
    Calculates the LDS score by comparing the model's predictions against the ground truth values, offering an option for raw or mean aggregation of results.

    :param score: A tensor or a list of tensors storing prediction scores that need evaluation.
    :param ground_truth: A tensor or a list of tensors storing the actual ground truth values for comparison.
    :param data_indices: The training-sample indices of each subset (one index list per retrained subset), used to form the counterfactual predictions.
    :param aggregation: Determines the type of aggregation ('raw' for a list or 'mean' for a single value) for the Spearman rank correlation results.
    :return: The Spearman rank correlation (a single value or a list of per-test-example values, depending on the aggregation type) together with the corresponding p-value(s).
    """

Demonstration

# train_func: user-defined function that takes a dataloader and returns a trained
# model (it encapsulates the loss, optimizer, and training loop internally).
retrain_subset(train_func, save_path="./lds_models", ratio=0.5, epochs=5,
               retrain_num=100, expectation_sample=5)

# Evaluate each retrained model on the test set to obtain the ground truth.
test_loader = DataLoader(my_test_dataset, batch_size=32)
model_path_list = [...]  # paths to the state dicts saved by retrain_subset
ground_truth = calculate_lds_groundtruth(model, model_path_list, test_loader, target_func)

# Attribution scores used as counterfactual predictors, e.g. a .pt file produced by TRAK.
score = torch.load("attribution_scores.pt")
data_indices = [...]  # subset indices saved by retrain_subset
lds_score, p_value = linear_datamodeling_score(score, ground_truth, data_indices)

Actionable Items (To-do List)

TheaperDeng commented 5 months ago

For train_lds:

  1. The name is too narrowly tied to the LDS metric; we may need this function for other use cases as well (e.g., TRAK model retraining). I suggest changing it to retrain_subset and implementing it in dattri/model_utils/retrain.py.
  2. We may want another parameter controlling the subset size, e.g., ratio=0.5.
  3. It's okay to take model, dataloader, and loss as parameters, but that means you must assume the user's training loop is a classical one and set up the optimizer and other details yourself. I think a train_func provided by the user could work: ask the user to supply a unified training function, e.g., one that takes a dataloader and returns a trained model (see the sketch after this list).
  4. Returning a tuple containing a list of trained models could be very memory-consuming. It would be better to save each model's state dict and training sample indices to disk as soon as it is trained (be sure to name the files systematically so you can later revisit and pair them up). This requires adding a save_path parameter.
  5. Change model_count to retrain_num.
  6. Change subset_model_count to expectation_sample.
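
For example, a unified training function of the kind suggested in point 3 might look like this (the architecture, optimizer, and epoch count are placeholders):

import torch
import torch.nn as nn

def train_func(dataloader):
    # Illustrative user-provided training function: takes a dataloader,
    # returns a trained model.
    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(10):
        for x, y in dataloader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
    return model
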
TheaperDeng commented 5 months ago

For cal_lds_gt:

  1. This method computes the ground truth for LDS. Please implement it in dattri/metrics/lds.py; I think calculate_lds_groundtruth would be clearer for future users. This ground truth has nothing to do with the attribution score.
  2. The models, test_loader, and target_func are the only parameters you need for this ground truth computation; you don't need data_indices or train_loader.
  3. For the models, you should assume users cannot load the whole model list into memory (memory-consuming again). Split it into two parameters, model and model_path_list: the first is an nn.Module and the second is a list of paths to the state dicts saved by retrain_subset.
  4. target_func should not be a data attribution method. It should take a model and the test_loader and return a tensor containing all the target outputs (see the sketch after this list).
  5. The return value could be a list of tensors: [target_func(model[0], test_loader), target_func(model[1], test_loader), ...].
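
As an illustration of the contract in point 4, a target function might return the per-sample loss on the test set (the choice of target here is only an example):

import torch

def target_func(model, test_loader):
    # Returns one target value per test example; here, the per-sample
    # cross-entropy loss.
    model.eval()
    outputs = []
    with torch.no_grad():
        for x, y in test_loader:
            loss = torch.nn.functional.cross_entropy(model(x), y, reduction='none')
            outputs.append(loss)
    return torch.cat(outputs)
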
TheaperDeng commented 5 months ago

For lds:

  1. This function computes the metric itself. Please implement it in dattri/metrics/lds.py; linear_datamodeling_score seems like a suitable name.
  2. You need a data_indices parameter for this function, because it is required to compute the "counterfactual predictions".
  3. The output should include the p-value.
  4. Add an aggregation parameter so that users can choose raw (return a list) or mean (return a single value).
TheaperDeng commented 5 months ago

Please revise the issue again and @-mention me for review before you implement anything.