Closed Jiaxin-Wen closed 1 year ago
Hi @Jiaxin-Wen,
good question! Indeed, we set the model output function as $\log\frac{p}{1-p}$. Note, however, that the logits are not $p$ here: $p$ is the softmax probability, and the logits are the model outputs pre-softmax. In other words, suppose we have an input $z=(x,y)$; then if we have logits $g\in\mathbb{R}^{\text{number of classes}}$, the softmax probability $p\in\mathbb{R}$ for class $y$ is given by $p=\frac{\exp(g_y)}{\sum_j \exp(g_j)}$. Substituting into our model output function $f=\log\frac{p}{1-p}$, we have
$$f = \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{1-\exp(g_y)/\sum_j \exp(g_j)}\right) = \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{\left(\sum_j \exp(g_j)-\exp(g_y)\right)/\sum_j \exp(g_j)}\right) = \log\left(\frac{\exp(g_y)}{\sum_{j\neq y} \exp(g_j)}\right) = g_y - \log\left(\sum_{j\neq y} \exp(g_j)\right).$$
Now that we have $g_y - \log\left(\sum_{j\neq y} \exp(g_j)\right)$, note that $g_y$ is just the correct-class logit, and $\log\left(\sum_{j\neq y} \exp(g_j)\right)$ is the logsumexp of the remaining logits.
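As a quick numerical sanity check of the identity above (a standalone sketch, not TRAK's actual code; the logits below are arbitrary):

```python
import torch

# Arbitrary logits for a 5-class problem and a chosen correct class y.
g = torch.tensor([1.2, -0.3, 0.5, 2.1, -1.0])
y = 3

# Left-hand side: f = log(p / (1 - p)), with p the softmax probability of class y.
p = torch.softmax(g, dim=0)[y]
f_via_p = torch.log(p / (1 - p))

# Right-hand side: correct-class logit minus logsumexp of the remaining logits.
others = torch.cat([g[:y], g[y + 1:]])
f_via_logits = g[y] - torch.logsumexp(others, dim=0)

assert torch.allclose(f_via_p, f_via_logits)
```

The two expressions agree, which is why the implementation can work directly on logits without ever materializing $p$.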
Thus, we implement the `get_output` function as:
https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/trak/modelout_functions.py#L415-L427
I can see how the difference between $g$ and $p$ can be confusing; I'll amend the docs to reflect that :)
Feel free to re-open this if you have any further questions.
Thanks for your clarification!
Another small question regarding the implementation of `get_out_to_loss_grad`:
https://github.com/MadryLab/trak/blob/76b13ca55f1a16a243a23aba312cf6aad57b84d0/trak/modelout_functions.py#L444
Why is the return value `1 - p` instead of `-(1 - p)`?
Ah, good catch. We sweep that under the rug a bit: the margin computation in `get_output` also has a minus sign, since $L$ is actually given by $-\log(p)$; thus, the two negatives cancel out.
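To see where the two minus signs come from: with $f=\log\frac{p}{1-p}$ we have $p=\sigma(f)$, so $L=-\log p=-\log\sigma(f)$ and $\frac{\partial L}{\partial f}=-(1-p)$. A small autograd sketch illustrating this (illustrative only, not TRAK's code):

```python
import torch

# With f = log(p / (1 - p)), we have p = sigmoid(f) and L = -log(p) = -log(sigmoid(f)).
f = torch.tensor(0.7, requires_grad=True)
p = torch.sigmoid(f)
loss = -torch.log(p)
loss.backward()

# Autograd gives dL/df = -(1 - p); that minus sign cancels against the one
# in the margin computation, so returning 1 - p gives the right result.
assert torch.allclose(f.grad, -(1 - p.detach()))
```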
The output function for text classification is $\log\frac{p}{1-p}$. In the current implementation (`TextClassificationModelOutput`), `cloned_logits.logsumexp(dim=-1)` correctly calculates $\log(1-p)$. However, I wonder if `logits_correct` needs to be modified into `torch.log(logits_correct)` to calculate $\log p$?