alipay / PainlessInferenceAcceleration

Creative Commons Attribution 4.0 International

Counting how many forward passes/steps were done when using PAIN #14

Open jivanph opened 8 months ago

jivanph commented 8 months ago

I wanted to ask if there's a way to count how many forward passes/steps are done when using PAIN, to contrast it with standard decoding.

jivanph commented 8 months ago

On a different note, what are the parameters for the tree object? How many branches are made and how deep does the tree go?

zheyishine commented 8 months ago

You can count the steps in two ways. One is to turn on debug_lookahead, which prints debug info for each step so you can count the steps manually. The other is to turn on return_dict_in_generate in the model's generation method; the kwargs of the outputs then contain a decoding summary, and len(kwargs['dls']) is the step count.

We use different parameters for different tasks. As mentioned in the readme of our repo, we use decoding_length=128 (i.e., forward token count) and branch_length=32 (i.e., tree depth) for RAG tasks, and decoding_length=64 and branch_length=8 for dolly and GSM8K tasks. We do not use a branch count parameter because we care more about the actual token count in a forward pass than about the number of logical branches.
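
For reference, the per-task settings quoted above can be collected in one place. This is only a sketch: the dictionary name, the task keys, and the helper function are hypothetical, and how the values are passed to the library is not shown in this thread, so that part is left out.

```python
# Per-task lookahead settings as quoted in the comment above.
# The dict name, task keys, and helper are hypothetical (not from the repo).
LOOKAHEAD_PRESETS = {
    "rag": {"decoding_length": 128, "branch_length": 32},
    "dolly": {"decoding_length": 64, "branch_length": 8},
    "gsm8k": {"decoding_length": 64, "branch_length": 8},
}


def preset_for(task):
    """Return a copy of the settings for a task name (case-insensitive)."""
    return dict(LOOKAHEAD_PRESETS[task.lower()])


rag_settings = preset_for("RAG")
print(rag_settings["decoding_length"], rag_settings["branch_length"])
```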

jivanph commented 8 months ago

Thank you so much for your response. This helped me greatly. If I understand correctly, to count how many draft tokens were used in total when using PAIN, I could just compute sum(kwargs['dls'])?

zheyishine commented 8 months ago

It should be sum(kwargs['dls'])-len(kwargs['dls']). Each entry in dls (the decoding length of one step) is composed of the next token plus the draft tokens, so we need to subtract 1 per step.
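
A minimal sketch of the arithmetic discussed in this thread, assuming dls is a plain list of per-step decoding lengths as described above. The function name and the sample values are hypothetical; in practice the list would come from kwargs['dls'] in the generation outputs.

```python
def summarize_lookahead(dls):
    """Summarize a list of per-step decoding lengths.

    Each entry is assumed to count one next token plus that step's draft
    tokens, so the draft total is the sum minus one token per step.
    """
    steps = len(dls)            # number of forward passes
    forward_tokens = sum(dls)   # all tokens fed through the forward passes
    draft_tokens = forward_tokens - steps
    return {
        "steps": steps,
        "forward_tokens": forward_tokens,
        "draft_tokens": draft_tokens,
    }


# Hypothetical values for illustration only, not measured output.
example_dls = [1, 64, 64, 37]
summary = summarize_lookahead(example_dls)
print(summary)
```

To contrast with standard decoding, compare summary["steps"] against the number of generated tokens, since standard decoding runs one forward pass per generated token.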