yangkevin2 / naacl-2021-fudge-controlled-generation

MIT License

Calculating the full logit for formality #1

Open YipingNUS opened 3 years ago

YipingNUS commented 3 years ago

@yangkevin2, really nice paper and code. I have a question about how the full logits are calculated in predict_formality.py.

It seems like you're adding the raw logits top_logits to condition_logits, which are log probabilities. Why is that? I would have expected that top_logits need to be converted to log probabilities before the two can be added together.

full_logits = top_logits + condition_lambda * condition_logits

I'm still trying to understand the code, so I may well be wrong. Thanks in advance!

yangkevin2 commented 3 years ago

Hi,

Mathematically, the logits are just log probabilities plus an additive constant (since you log_softmax the logits to get log probabilities), and that constant is the same for every token at a given position. You could calculate the log probabilities of top_logits first, but it shouldn't change the result, since the additive constant would just be "softmaxed away" later.
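
A quick way to check this is a minimal sketch with made-up numbers. The values below are hypothetical and not taken from predict_formality.py, and the helper functions are plain-Python stand-ins for the usual log_softmax/softmax. It compares adding the raw logits against adding their log probabilities, then applies softmax to both:

```python
import math

def log_softmax(xs):
    # Subtract log-sum-exp; the max is factored out for numerical stability.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def softmax(xs):
    return [math.exp(v) for v in log_softmax(xs)]

# Hypothetical values for a single position (illustration only).
top_logits = [2.0, 0.5, -1.0, 3.2]           # raw logits from the language model
condition_logits = [-0.1, -2.3, -0.7, -1.5]  # log probabilities from the classifier
condition_lambda = 1.0

# Variant A: add the raw logits directly, as in the line quoted above.
raw_sum = [t + condition_lambda * c
           for t, c in zip(top_logits, condition_logits)]

# Variant B: convert top_logits to log probabilities first.
normalized_sum = [t + condition_lambda * c
                  for t, c in zip(log_softmax(top_logits), condition_logits)]

probs_raw = softmax(raw_sum)
probs_normalized = softmax(normalized_sum)

# The two variants differ only by a constant shift, which softmax removes.
max_diff = max(abs(a - b) for a, b in zip(probs_raw, probs_normalized))
print(max_diff)
```

The printed difference should be zero up to floating-point error, because the two sums differ only by the log-sum-exp of top_logits, which is the same for every token.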