zjost / blog_code

A repo for holding example code
Apache License 2.0

Difficulty in understanding xent function #11

Closed. utkarsh0902311047 closed this issue 1 year ago.

utkarsh0902311047 commented 1 year ago

Hi, thank you so much for the detailed code. I would really appreciate an explanation of the xent function, specifically how it uses np.arange and np.argmax. What is it actually doing?

zjost commented 1 year ago

Hello. You mean this?

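import numpy as np
def xent(pred, labels):
    # Per-sample cross entropy: -log of the predicted probability of each row's true class
    return -np.log(pred)[np.arange(pred.shape[0]), np.argmax(labels, axis=1)]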
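A small worked example of the indexing step may help. It uses made-up probabilities for a batch of 3 samples and 3 classes (the numbers are purely illustrative). `np.arange(3)` supplies the row indices and `np.argmax(labels, axis=1)` supplies one column index per row; NumPy pairs the two index arrays elementwise, so exactly one entry is picked from each row.

```python
import numpy as np

def xent(pred, labels):
    return -np.log(pred)[np.arange(pred.shape[0]), np.argmax(labels, axis=1)]

# Illustrative (made-up) predicted probabilities: 3 samples, 3 classes; each row sums to 1
pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.8, 0.1],
                 [0.2, 0.3, 0.5]])
# One-hot labels: the true classes are 0, 2 and 2
labels = np.array([[1, 0, 0],
                   [0, 0, 1],
                   [0, 0, 1]])

rows = np.arange(pred.shape[0])      # row index for every sample
cols = np.argmax(labels, axis=1)     # position of the 1 in each label row
print(rows, cols)                    # [0 1 2] [0 2 2]

# Advanced indexing pairs rows[i] with cols[i]: picks pred[0, 0], pred[1, 2], pred[2, 2]
print(pred[rows, cols])              # [0.7 0.1 0.5]

# xent returns one loss per sample: -log of those picked probabilities
print(xent(pred, labels))
```

Note that the second sample gets a large loss because the model put only 0.1 on its true class.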

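Recall that cross entropy is $-\sum_c y_c \log p_c$, where $y_c$ is 1 if $c$ is the true class and 0 otherwise, and $p_c$ is the predicted probability that the class is $c$. Every term in the sum is zero except the one where $y_c = 1$, and for one-hot labels the index of that class is given by np.argmax(labels, axis=1). So we take the negative log of every predicted probability and then, for each row, select the single entry where $y_c = 1$. The np.arange(pred.shape[0]) is what makes that selection row-wise: paired with the argmax array, it picks element (i, argmax(labels[i])) for each sample i. Replacing it with : would not do the same thing, because a slice combined with an index array selects every listed column for every row, giving an N-by-N array instead of one loss per sample.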
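To check that the indexing shortcut matches the summation formula, the sketch below compares `xent` with a hypothetical helper `xent_sum` that computes $-\sum_c y_c \log p_c$ directly, then shows what happens if `np.arange(...)` is swapped for `:`. The random softmax predictions and one-hot labels are illustrative assumptions; the shortcut is only equivalent to the formula when each row of `labels` is one-hot.

```python
import numpy as np

def xent(pred, labels):
    # Pick -log p for each row's true class (assumes one-hot labels)
    return -np.log(pred)[np.arange(pred.shape[0]), np.argmax(labels, axis=1)]

def xent_sum(pred, labels):
    # Hypothetical helper: the literal formula -sum_c y_c log p_c, one value per row
    return -np.sum(labels * np.log(pred), axis=1)

rng = np.random.default_rng(0)
n, k = 5, 4
logits = rng.normal(size=(n, k))
# Softmax: turn arbitrary scores into rows of probabilities that sum to 1
pred = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
# Random one-hot labels, one row per sample
labels = np.eye(k)[rng.integers(0, k, size=n)]

print(np.allclose(xent(pred, labels), xent_sum(pred, labels)))  # True

# Swapping np.arange for ':' mixes a slice with an index array, which selects
# every row for every column index: the result is n x n, not one value per row
cols = np.argmax(labels, axis=1)
with_colon = -np.log(pred)[:, cols]
print(with_colon.shape)                                         # (5, 5)
# Only the diagonal holds the per-sample losses that xent returns
print(np.allclose(np.diag(with_colon), xent(pred, labels)))     # True
```

The indexing version skips multiplying by all the zero entries of `labels`, but unlike the explicit sum it silently ignores soft (non-one-hot) labels, since `np.argmax` keeps only the position of the largest entry.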