azshue / TPT

Test-time Prompt Tuning (TPT) for zero-shot generalization in vision-language models (NeurIPS 2022)
https://azshue.github.io/TPT/
MIT License

Question about implementation of log softmax #5

Closed · linzhiqiu closed this issue 1 year ago

linzhiqiu commented 1 year ago

I am curious why you used

```python
logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)  # [N, 1000]
```

instead of `logits = outputs.log_softmax(dim=1)`, and similarly

```python
avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])  # [1, 1000]
```

instead of `avg_logits = logits.mean(0)`.
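For reference, a minimal sketch (not from the repo; shapes and values are assumed) of what each expression computes:

```python
import numpy as np
import torch

N, C = 8, 1000  # assumed: N augmented views, C classes
outputs = torch.randn(N, C)

# The first pair is identical by definition of log_softmax:
# log_softmax(x) = x - logsumexp(x).
logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
assert torch.allclose(logits, outputs.log_softmax(dim=-1), atol=1e-6)

# logsumexp(dim=0) - log(N) is a stable log-mean-exp: the log of the
# probabilities averaged across the N views.
avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])
assert torch.allclose(avg_logits, logits.exp().mean(dim=0).log(), atol=1e-5)

# Note: logits.mean(0) averages the log-probabilities instead, which is a
# different quantity in general (Jensen's inequality).
print((avg_logits - logits.mean(dim=0)).abs().max())
```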

Thanks!

azshue commented 1 year ago

Hi @linzhiqiu,

Thank you for your interest in our work and implementation.

For calculating the averaged entropy, we follow the implementation of MEMO. Using logsumexp is a trick for numerical stability; it could be that PyTorch's implementation of log_softmax already accounts for this, in which case either form may be fine. (I don't recall running into problems with the alternative implementation; most likely I simply never used it.)
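To illustrate the stability point, a quick sketch (values chosen only to force underflow):

```python
import torch

outputs = torch.tensor([[1000.0, 0.0, -1000.0]])

# Naive route: exponentiate first, then take the log. The small
# probabilities underflow to 0, so the log returns -inf.
print(outputs.softmax(dim=-1).log())  # tensor([[0., -inf, -inf]])

# logsumexp trick: stay in log space the whole time.
print(outputs - outputs.logsumexp(dim=-1, keepdim=True))  # tensor([[0., -1000., -2000.]])

# PyTorch's fused log_softmax applies the same kind of stabilization
# internally, so it matches the logsumexp version.
print(outputs.log_softmax(dim=-1))  # tensor([[0., -1000., -2000.]])
```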

linzhiqiu commented 1 year ago

Okay, that makes a lot of sense. I recently came across this issue in another project and was worried that torch's log_softmax didn't use this trick.

Thanks! Very impressive work, and we are considering following up on it!