mmuneeburahman opened 11 months ago
According to my understanding, this is because of the derivative of the loss w.r.t. the softmax output.
I think the definition of cross-entropy in the cs231n notes is not the one that is generally written, or is somewhat incomplete.
Categorical Cross-Entropy Loss:
$$L = -\sum_j y_j \log(p_j)$$
However, it simplifies to the equation in the notes, since $y_j$ is 1 for exactly one class and $y_j$ is 0 for every other class.
The derivative of $L$ with respect to the score of the correct class $j$ is then $p_j - 1$.
Ref: https://math.stackexchange.com/questions/945871/derivative-of-softmax-loss-function
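To make this concrete, here is a minimal NumPy sketch (the scores `f` and class index `y` are made up for illustration) showing that the gradient of the loss w.r.t. the raw scores is the softmax vector with 1 subtracted at the correct class:

```python
import numpy as np

f = np.array([2.0, 1.0, 0.1])   # hypothetical class scores for one sample
y = 0                           # index of the correct class

f_shifted = f - np.max(f)       # shift scores for numerical stability
p = np.exp(f_shifted) / np.sum(np.exp(f_shifted))  # softmax probabilities

loss = -np.log(p[y])            # cross-entropy loss with a one-hot target

grad = p.copy()                 # dL/df_j = p_j for j != y ...
grad[y] -= 1.0                  # ... and p_j - 1 for the correct class
print(loss, grad)
```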
Just to clarify: when computing the loss for any sample $n$, we only consider the predicted score of the class $c$ of interest, i.e., the class $c$ that is actually correct for that sample $n$:
$$L_n=\begin{cases}-\log\left(\frac{e^{\hat y_i}}{\sum_{i'} e^{\hat y_{i'}}}\right), & \text{when } i=c \\ 0, & \text{otherwise}\end{cases}$$
$$\nabla_{\hat y_i}L_n=\begin{cases}\left(\frac{e^{\hat y_i}}{\sum_{i'} e^{\hat y_{i'}}}\right)-1, & \text{when } i=c \\ \left(\frac{e^{\hat y_i}}{\sum_{i'} e^{\hat y_{i'}}}\right), & \text{otherwise}\end{cases}$$
The actual steps to arrive at the final derivative shown above are a bit more involved, but you can easily work them out with pen and paper using the standard differentiation rules (chain rule, quotient rule, logarithm rule, etc.), e.g., $\nabla_x \log(x)= \frac{1}{x}$ (we assume $\log=\ln$).
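If you prefer not to do the pen-and-paper derivation, a quick numerical sanity check works too. This is just a sketch with illustrative values, comparing the piecewise gradient above against a centered finite-difference estimate:

```python
import numpy as np

def loss_fn(scores, c):
    """Cross-entropy loss of the correct class c under softmax(scores)."""
    p = np.exp(scores - scores.max())
    p /= p.sum()
    return -np.log(p[c])

scores = np.array([0.5, -1.2, 2.0])     # hypothetical raw scores
c = 2                                   # correct class
eps = 1e-6

# Analytic gradient: p_i - 1 when i == c, else p_i
p = np.exp(scores - scores.max())
p /= p.sum()
analytic = p.copy()
analytic[c] -= 1.0

# Centered finite differences
numeric = np.zeros_like(scores)
for i in range(len(scores)):
    plus, minus = scores.copy(), scores.copy()
    plus[i] += eps
    minus[i] -= eps
    numeric[i] = (loss_fn(plus, c) - loss_fn(minus, c)) / (2 * eps)

print(np.max(np.abs(analytic - numeric)))  # should be tiny, ~1e-10
```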
Intuitively, that $-1$ accounts for $\hat{y}_c$ appearing in the numerator in the original calculation of $L_n$, which does not happen for any other $\hat{y}_i$ with $i \ne c$.
Extra note: the full cross-entropy loss is calculated as the average negative log probability of the correct class across all your data points. For a single data point $n$, its negative log probability is just the "surprisal" of that event according to information theory.
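As a small sketch of that (the scores and labels here are made up), the full loss is just the mean surprisal of the correct class over the batch:

```python
import numpy as np

scores = np.array([[2.0, 1.0, 0.1],
                   [0.3, 2.2, 0.5]])   # shape (N, C), hypothetical scores
labels = np.array([0, 1])              # correct class for each sample

shifted = scores - scores.max(axis=1, keepdims=True)  # numerical stability
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# per-sample surprisal -log p(correct class), averaged over the batch
loss = -np.log(probs[np.arange(len(labels)), labels]).mean()
print(loss)
```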
Edit: for clarity, I should add that I use a different notation from the one in the cs231n notes: in the cs231n notes, $f_{y_i}$ means the predicted score for the class $y_i$ that is actually correct for sample $i$.
Here is the partial-derivative Jacobian matrix for softmax, writing $p_i = \frac{e^{z_i}}{\sum_k e^{z_k}}$:
$$\frac{\partial p_i}{\partial z_j} = \frac{\delta_{ij}\, e^{z_i} \sum_k e^{z_k} - e^{z_i} e^{z_j}}{\left(\sum_k e^{z_k}\right)^2}$$
This simplifies to:
$$\frac{\partial p_i}{\partial z_j} = p_i\,(\delta_{ij} - p_j)$$
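For what it's worth, this Jacobian is easy to sanity-check in NumPy; here is a small sketch with an arbitrary score vector:

```python
import numpy as np

z = np.array([1.0, 2.0, 3.0])          # arbitrary raw scores
p = np.exp(z - z.max())
p /= p.sum()                           # softmax probabilities

# dp_i/dz_j = p_i * (delta_ij - p_j), i.e. diag(p) - outer(p, p)
J = np.diag(p) - np.outer(p, p)
print(J)
```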
I didn't get this. Can someone explain? For reference, see the blog.