Closed · talha1503 closed this issue 6 months ago
Hi! In `dola.py`, on line 197 we compute `js_divs = js_divs.mean(-1)`, which is then used to select the premature layer. Does this mean the premature layer is the same for every token/logit in the batch? Could you help me out with this? @voidism

Hi @talha1503!
Line 197 is inside the for loop that walks through all the decoding steps, so each token gets its own selected premature layer at each decoding step. The `.mean(-1)` does make all examples in the batch share the same premature layer at each time step; however, we always use batch size = 1 for these large LLaMA models, so it is not an issue.
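To make the batch-averaging behavior concrete, here is a minimal sketch of this selection step. It is not the repository's exact code: the function name `select_premature_layer`, the tensor shapes, and the clamping constant are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def select_premature_layer(candidate_logits: torch.Tensor,
                           mature_logits: torch.Tensor) -> int:
    """Pick one premature layer for the whole batch at the current
    decoding step by maximizing the Jensen-Shannon divergence against
    the mature (final) layer.

    Assumed shapes (hypothetical, for illustration):
      candidate_logits: (num_layers, batch, vocab)
      mature_logits:    (batch, vocab)
    """
    p = F.softmax(candidate_logits, dim=-1)            # premature distributions
    q = F.softmax(mature_logits, dim=-1).unsqueeze(0)  # mature distribution, broadcast over layers
    m = 0.5 * (p + q)                                  # mixture distribution
    # JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), summed over the vocab
    log_m = m.clamp_min(1e-10).log()
    kl_pm = (p * (p.clamp_min(1e-10).log() - log_m)).sum(-1)
    kl_qm = (q * (q.clamp_min(1e-10).log() - log_m)).sum(-1)
    js_divs = 0.5 * (kl_pm + kl_qm)                    # shape: (num_layers, batch)
    # .mean(-1) collapses the batch dimension, so the argmax below picks a
    # single layer shared by every example in the batch; with batch size 1
    # (as used for the large LLaMA models) this averaging is a no-op.
    js_divs = js_divs.mean(-1)                         # shape: (num_layers,)
    return int(js_divs.argmax())

# Toy check: 3 candidate layers, batch of 2, vocab of 5.
layer = select_premature_layer(torch.randn(3, 2, 5), torch.randn(2, 5))
print(layer)
```

Because the selection re-runs inside the decoding loop, a new layer is chosen at every time step, even though within a single step the choice is shared across the batch.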
@voidism Understood! Thank you so much for your response!