Hi @linzhiqiu,
Thank you for your interest in our work and implementation.
For calculating the averaged entropy, we follow the implementation of MEMO. Using logsumexp
is a trick for numerical stability, but PyTorch's implementation of log_softmax may already
account for this, so it may be fine to use either.
(I don't recall running into problems with the alternative implementation; most likely I simply haven't tried the alternatives.)
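For reference, here is a minimal sketch of the averaged (marginal) entropy computation being discussed, assuming `outputs` is an `[N, 1000]` tensor of logits from N augmented views; the variable names follow the snippet quoted below rather than the exact repository code:

```python
import numpy as np
import torch

def avg_entropy(outputs):
    # outputs: [N, C] logits from N augmented views.
    # Per-view log-probabilities; subtracting logsumexp is the standard
    # numerically stable normalization (mathematically the same as log_softmax).
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)      # [N, C]
    # Log of the probabilities averaged over the N views:
    # log( (1/N) * sum_i p_i ), again computed via logsumexp.
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])  # [C]
    # Entropy of the averaged distribution: -sum_c p_c * log p_c.
    return -(avg_logits.exp() * avg_logits).sum()
```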
Okay, that makes a lot of sense -- I recently ran into this issue in another project and was worried that torch's log_softmax didn't use this trick.
Thanks! Very impressive work and we are considering following up!
I am curious why you used
`logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)`
instead of `logits = outputs.log_softmax(dim=1)  # [N, 1000]`,
and similarly
`avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])`
instead of `avg_logits = logits.mean(0)  # [1, 1000]`.
Thanks!
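For what it's worth, a quick illustrative check of the two substitutions (random `outputs` as a stand-in for the real logits): the first pair is the same operation up to floating-point error, while the second pair computes different quantities, namely the log of the averaged probabilities versus the average of the log-probabilities.

```python
import numpy as np
import torch

outputs = torch.randn(64, 1000)  # stand-in for [N, 1000] augmented-view logits

# First pair: identical up to floating-point error,
# since log_softmax(x) = x - logsumexp(x).
a = outputs - outputs.logsumexp(dim=-1, keepdim=True)
b = outputs.log_softmax(dim=-1)
print(torch.allclose(a, b, atol=1e-6))                   # True

# Second pair: different quantities in general.
log_avg_prob = a.logsumexp(dim=0) - np.log(a.shape[0])   # log of the averaged probabilities
avg_log_prob = a.mean(0)                                  # average of the log-probabilities
print(torch.allclose(log_avg_prob, avg_log_prob))         # False in general (Jensen's inequality)
```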