TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.algorithm] Better `distance_func` for KNN data shapley #124

Open sx-liu opened 2 months ago

sx-liu commented 2 months ago

The default `distance_func` of the current KNN Data Shapley attributor does not depend on the model. One potential improvement is to pass the model to the distance function through a `task` argument, so that distances can be computed in the model's embedding space.

def distance_func(batch_x, batch_y, task=None):
    ...

`distance_func` could then access the task and extract whatever it needs from it, such as the model.
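To make the proposal concrete, here is a minimal sketch of what such a function might look like. It is not the library's implementation. It assumes that `batch_x` and `batch_y` are feature tensors with a leading batch dimension, that the model's forward pass can serve as the embedding, and that the task exposes the model as a `model` attribute. `ToyTask` is a hypothetical stand-in, not dattri's actual task class.

```python
import torch
from torch import nn


class ToyTask:
    """Hypothetical stand-in for a task object; it only exposes a model."""

    def __init__(self, model):
        self.model = model


def distance_func(batch_x, batch_y, task=None):
    """Pairwise Euclidean distances, optionally in model embedding space."""
    if task is not None:
        # Assumption: the model's output is used as the embedding.
        model = task.model
        model.eval()  # note: this switches the model to eval mode
        with torch.no_grad():
            batch_x = model(batch_x)
            batch_y = model(batch_y)
    # Flatten everything except the batch dimension before comparing.
    batch_x = batch_x.flatten(start_dim=1)
    batch_y = batch_y.flatten(start_dim=1)
    # Result has shape (len(batch_x), len(batch_y)).
    return torch.cdist(batch_x, batch_y)


torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
train_x = torch.randn(5, 8)
test_x = torch.randn(3, 8)

# Without a task, distances are computed on the raw inputs.
raw_dist = distance_func(train_x, test_x)
# With a task, distances are computed on the model's outputs.
emb_dist = distance_func(train_x, test_x, task=ToyTask(encoder))
```

Keeping `task=None` as the default means existing callers that pass only two arguments would keep the current model-independent behavior.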