jivanph opened 8 months ago
On a different note, what are the parameters for the tree object? How many branches are made and how deep does the tree go?
You can count the steps with two methods. One is turning on debug_lookahead, which outputs debug info for each step, so you can count the steps manually. The other is turning on return_dict_in_generate in the model.generate method; the kwargs of the outputs will contain a decoding summary, and len(kwargs['dls']) is the step count.
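As a minimal sketch of the second method: the `outputs_kwargs` dict below is simulated stand-in data rather than a real `generate()` result, and the `dls` values are made up for illustration; only the `dls` key and the step-count formula come from the reply above.

```python
# Simulated stand-in for the kwargs attached to the generate() outputs
# when return_dict_in_generate is turned on with lookahead decoding.
# Each entry of 'dls' summarizes one decoding step (values are hypothetical).
outputs_kwargs = {"dls": [9, 5, 7, 1, 12]}

# Step count = number of forward passes taken during generation.
step_count = len(outputs_kwargs["dls"])
print(step_count)  # 5
```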
We use different parameters for different tasks. As mentioned in the readme of our repo, we use decoding_length=128 (i.e., forward token count) and branch_length=32 (i.e., tree depth) for RAG tasks, and decoding_length=64 and branch_length=8 for the Dolly and GSM8K tasks. We do not use a branch-count parameter, as we care more about the actual token count in a forward pass than about the number of logical branches.
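For reference, the per-task settings above can be collected in a small lookup helper. The helper and the task keys are illustrative (not part of the repo's API); only the parameter names and values come from the readme as quoted above.

```python
# Suggested lookahead parameters per task, per the repo readme.
# decoding_length: tokens forwarded per pass; branch_length: tree depth.
LOOKAHEAD_PARAMS = {
    "rag":   {"decoding_length": 128, "branch_length": 32},
    "dolly": {"decoding_length": 64,  "branch_length": 8},
    "gsm8k": {"decoding_length": 64,  "branch_length": 8},
}

def lookahead_params(task: str) -> dict:
    """Return the suggested decoding parameters for a given task."""
    return LOOKAHEAD_PARAMS[task.lower()]
```

For example, `lookahead_params("RAG")` returns `{"decoding_length": 128, "branch_length": 32}`.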
Thank you so much for your response. This helped me greatly. If I understand correctly, if I want to count how many draft tokens in total were used when using PAIN, I could just compute sum(kwargs['dls']).
It should be sum(kwargs['dls']) - len(kwargs['dls']), because each decoding length (i.e., each entry of dls) is composed of the next token plus the draft tokens, so we subtract 1 per step.
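Putting both replies together, a minimal sketch of the two counts (the `dls` values below are fabricated for illustration; only the formulas come from the thread):

```python
# Per-step decoding lengths, e.g. taken from kwargs['dls'] of the
# generate() outputs; these values are made up for illustration.
dls = [9, 5, 7, 1, 12]

steps = len(dls)  # number of forward passes / decoding steps
# Each step yields one "next" token plus draft tokens,
# so subtract one token per step to count draft tokens only.
draft_tokens = sum(dls) - steps
print(steps, draft_tokens)  # 5 29
```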
I wanted to ask if there's a way to count how many forward passes/steps are done when using PAIN, to contrast it with standard decoding.