MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License
175 stars, 24 forks

Regarding the implementation of `TextClassificationModelOutput` #51

Closed. Jiaxin-Wen closed this issue 11 months ago.

Jiaxin-Wen commented 11 months ago

The output function for text classification is $\log\frac{p}{1-p}$. In the current implementation (`TextClassificationModelOutput`), `cloned_logits.logsumexp(dim=-1)` correctly calculates $\log(1-p)$. However, I wonder whether `logits_correct` should be changed to `torch.log(logits_correct)` to calculate $\log p$.

kristian-georgiev commented 11 months ago

Hi @Jiaxin-Wen,

Good question! Indeed, we set the model output function to $\log\frac{p}{1-p}$. Note, however, that the logits are not $p$ here: $p$ is the softmax probability, while the logits are the model's pre-softmax outputs. Concretely, suppose we have an input $z=(x,y)$ with logits $g\in\mathbb{R}^{C}$, where $C$ is the number of classes. The softmax probability $p\in\mathbb{R}$ of class $y$ is then $p=\frac{\exp(g_y)}{\sum_j \exp(g_j)}$. Substituting this into the model output function $f=\log\frac{p}{1-p}$ gives

$$
\begin{aligned}
f &= \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{1-\exp(g_y)/\sum_j \exp(g_j)}\right) \\
  &= \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{\left(\sum_j \exp(g_j)-\exp(g_y)\right)/\sum_j \exp(g_j)}\right) \\
  &= \log\left(\frac{\exp(g_y)}{\sum_{j\neq y} \exp(g_j)}\right) \\
  &= g_y - \log\left(\sum_{j\neq y} \exp(g_j)\right).
\end{aligned}
$$

In $g_y - \log\left(\sum_{j\neq y} \exp(g_j)\right)$, the term $g_y$ is just the correct-class logit, and $\log\left(\sum_{j\neq y} \exp(g_j)\right)$ is the logsumexp of the remaining logits. Thus, we implement the `get_output` function as: https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/trak/modelout_functions.py#L415-L427
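To check the derivation numerically, here is a minimal pure-Python sketch (plain `math` rather than PyTorch, so it runs anywhere). It assumes, as described above, that the correct-class logit is masked out before taking the logsumexp of the remaining logits; the helper names `logsumexp`, `margin_output`, and `log_odds_from_softmax` are hypothetical and are not part of TRAK's API.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) using the max-shift trick."""
    m = max(values)
    if m == -math.inf:
        # Every entry is -inf, so the sum of exponentials is 0 and its log is -inf.
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def margin_output(logits: Sequence[float], label: int) -> float:
    """Hypothetical helper: f = g_y - log(sum_{j != y} exp(g_j)), straight from logits."""
    # Replace the correct-class logit with -inf so exp() of it contributes 0
    # to the logsumexp, leaving only the "other class" logits.
    masked = [-math.inf if j == label else g for j, g in enumerate(logits)]
    return logits[label] - logsumexp(masked)


def log_odds_from_softmax(logits: Sequence[float], label: int) -> float:
    """Hypothetical helper: f = log(p / (1 - p)), going through the softmax probability p."""
    log_z = logsumexp(logits)
    p = math.exp(logits[label] - log_z)
    return math.log(p / (1.0 - p))


logits = [2.0, -1.0, 0.5]
print(margin_output(logits, 0))          # margin form
print(log_odds_from_softmax(logits, 0))  # agrees up to floating-point rounding
```

The two routes agree on ordinary inputs. For very confident logits such as `[40.0, 0.0, 0.0]` with label `0`, the softmax route rounds $p$ to exactly `1.0` in double precision and `log_odds_from_softmax` raises `ZeroDivisionError`, while `margin_output` still returns $40-\log 2$. That is one plausible reason to compute $f$ directly from the logits, though the thread itself does not state it.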

I can see how the difference between $g$ and $p$ can be confusing; I'll amend the docs to reflect that :)

kristian-georgiev commented 11 months ago

Pushed the docs update.

Feel free to re-open this if you have any further questions.

Jiaxin-Wen commented 11 months ago

Thanks for the clarification! Another small question regarding the implementation of `get_out_to_loss_grad` (https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/trak/modelout_functions.py#L444): why is the return value $1-p$ instead of $-(1-p)$?

kristian-georgiev commented 11 months ago

Ah, good catch. We sweep that under the rug a bit: the margin computation in `get_output` also carries a minus sign, since the loss $L$ is actually $-\log(p)$; thus, the two negatives cancel out.
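The sign question can be checked with a small finite-difference sketch. Because $f=\log\frac{p}{1-p}$ is the log-odds, $p=\sigma(f)=1/(1+e^{-f})$, and with $L=-\log p$ the chain rule gives $\frac{\partial L}{\partial f}=-(1-p)$. Returning $1-p$ therefore flips the sign of this derivative; per the reply above, that flip is cancelled by a matching sign convention elsewhere. The snippet is plain Python with hypothetical helper names; it only verifies the calculus and does not reproduce TRAK's `get_out_to_loss_grad`.

```python
import math
from typing import Callable


def sigmoid(f: float) -> float:
    """Logistic function: inverts the log-odds f = log(p / (1 - p))."""
    return 1.0 / (1.0 + math.exp(-f))


def loss_from_output(f: float) -> float:
    """Cross-entropy loss L = -log p, written as a function of the output f."""
    return -math.log(sigmoid(f))


def analytic_out_to_loss_grad(f: float) -> float:
    """dL/df = -(1 - p) from the chain rule, with p = sigmoid(f)."""
    return -(1.0 - sigmoid(f))


def numeric_grad(fn: Callable[[float], float], x: float, h: float = 1e-6) -> float:
    """Central finite-difference approximation of fn'(x)."""
    return (fn(x + h) - fn(x - h)) / (2.0 * h)


f = 1.3
print(analytic_out_to_loss_grad(f))       # negative: -(1 - p)
print(numeric_grad(loss_from_output, f))  # matches the analytic value closely
```

Both values are negative, confirming that the raw derivative is $-(1-p)$ and that a positive $1-p$ only makes sense together with the compensating sign described in the reply.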