voidism / DoLa

Official implementation for the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models"
https://arxiv.org/abs/2309.03883

Query regarding JS Divergence mean over batches. #13

Closed: talha1503 closed this issue 6 months ago

talha1503 commented 7 months ago

Hi! In dola.py, line 197 computes js_divs = js_divs.mean(-1), and this value is then used to select the premature layer. Does this mean the premature layer is the same for all tokens/logits in the batch? Could you help me out with this, please? @voidism

voidism commented 7 months ago

Hi @talha1503! Line 197 sits inside the for loop that walks through all the decoding steps, so each token gets its own selected layer at its own decoding step. The .mean(-1) does make all examples in a batch share the same premature layer at a given time step. However, we always use batch size = 1 for these large LLaMA models, so it is not an issue.
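
For readers following along, here is a minimal sketch of the per-step premature-layer selection being discussed; it is not the repository's actual code, and the names pick_premature_layer, final_logits, and candidate_logits are hypothetical, but the js_divs.mean(-1) step mirrors line 197:

```python
import torch
import torch.nn.functional as F

def pick_premature_layer(final_logits: torch.Tensor,
                         candidate_logits: torch.Tensor) -> int:
    """Sketch of DoLa-style premature-layer selection at one decoding step.

    final_logits:     [batch, vocab]         logits from the final (mature) layer
    candidate_logits: [layers, batch, vocab] logits from the candidate early layers
    Returns the index of the candidate layer whose distribution has the largest
    JS divergence from the final layer, averaged over the batch.
    """
    # Next-token distributions for the mature layer and each candidate layer
    p = F.softmax(final_logits, dim=-1)            # [batch, vocab]
    q = F.softmax(candidate_logits, dim=-1)        # [layers, batch, vocab]
    m = 0.5 * (p.unsqueeze(0) + q)                 # mixture distribution

    # JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), per layer and per example
    kl_pm = (p.unsqueeze(0) * (p.unsqueeze(0).log() - m.log())).sum(-1)
    kl_qm = (q * (q.log() - m.log())).sum(-1)
    js_divs = 0.5 * (kl_pm + kl_qm)                # [layers, batch]

    # Averaging over the batch dimension: every example in the batch shares
    # the same selected premature layer at this decoding step
    js_divs = js_divs.mean(-1)                     # [layers]
    return int(js_divs.argmax())
```

With batch size = 1, as used in the repo for the large LLaMA models, the mean over the batch dimension is a no-op, so the selection is effectively per-example as well as per-step.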

talha1503 commented 6 months ago

@voidism Understood! Thank you so much for your response!