hao-ai-lab / LookaheadDecoding

[ICML 2024] Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
https://arxiv.org/abs/2402.02057
Apache License 2.0

[BUG Report] jacobi_greedy_search_multilevel function bug #56

Status: Open. Opened by yangbohust 7 months ago.


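When the first `GUESS_SIZE` elements of the `correct` list match the `myguess` list, every guessed token has been accepted. In that case the last element of `correct`, which is the model's prediction following the final guessed token, is also a correct token, so it should be added to the `hits` list.

The original code cannot do this. It uses `for gg in range(len(myguess))`, so when no mismatch occurs, `gg` stops at `len(myguess) - 1`. A fully matched guess therefore produces the same `gg`, and the same number of accepted tokens, as a guess that fails only at its last position. In addition, `hits` is allocated with only `GUESS_SIZE` slots. The modified code counts matches with a `while` loop, so `gg` reaches `len(myguess)` on a full match, and it allocates `hits` with `GUESS_SIZE + 1` slots.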
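The loop behaviour described above can be reproduced outside the library. In the minimal sketch below, the two loop bodies are copied from the original and modified code. The helper names `matched_prefix_for` and `matched_prefix_while` and the token IDs are hypothetical, with `GUESS_SIZE = 3`.

```python
def matched_prefix_for(myguess, correct):
    """Original loop: gg is the last index visited by range()."""
    gg = 0
    for gg in range(len(myguess)):
        if myguess[gg] != correct[gg]:
            break
    return gg


def matched_prefix_while(myguess, correct):
    """Modified loop: gg counts how many leading guesses matched."""
    gg = 0
    while gg < len(myguess):
        if myguess[gg] != correct[gg]:
            break
        gg += 1
    return gg


correct = [11, 12, 13, 14]   # GUESS_SIZE + 1 entries (hypothetical token IDs)
all_match = [11, 12, 13]     # every guess is accepted
last_wrong = [11, 12, 99]    # mismatch only at the final position

print(matched_prefix_for(all_match, correct))     # 2
print(matched_prefix_for(last_wrong, correct))    # 2
print(matched_prefix_while(all_match, correct))   # 3
print(matched_prefix_while(last_wrong, correct))  # 2
```

With the `for` loop, both cases yield 2. With the `while` loop, only the fully matched guess reaches 3.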

https://github.com/hao-ai-lab/LookaheadDecoding/blob/9d50de4a81d1b473bfce104ace18fbbbb6dc3255/lade/decoding.py#L1068C1-L1085C88

Original code:

```python
hits = [first_guess] + [0] * (GUESS_SIZE - 1)
#multi-level window is filled
#match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        for gg in range(len(myguess)):
            if myguess[gg] != correct[gg]:
                break
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
#max_hit is the length of longest accepted sequence in verification branch
```
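To see the effect on `hits`, the original matching block can be run on plain lists outside the library. This sketch rests on several assumptions. `torch.argmax(...)` is replaced by a precomputed `guess_results` list, and `max_hit` and `max_hit_idx` are assumed to start at 0, as the `gg > max_hit` test suggests. The function name `match_original`, `GUESS_SIZE = 3`, and the token IDs are made up for illustration.

```python
GUESS_SIZE = 3  # hypothetical window size for this toy example


def match_original(first_guess, guess_results, guess_tokens, max_hit=0):
    """Standalone copy of the original matching logic on plain lists."""
    max_hit_idx = 0
    hits = [first_guess] + [0] * (GUESS_SIZE - 1)  # only GUESS_SIZE slots
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        # for/range: gg stops at len(myguess) - 1 even when every guess matches
        for gg in range(len(myguess)):
            if myguess[gg] != correct[gg]:
                break
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
    return max_hit, max_hit_idx, hits


# All three guesses accepted; the model's next prediction is 14.
print(match_original(11, [12, 13, 14], [11, 12, 13]))  # (2, 0, [11, 12, 13])
# Guess fails at the last position; the correction is 13.
print(match_original(11, [12, 13, 55], [11, 12, 99]))  # (2, 0, [11, 12, 13])
```

Both calls return the same result. In the first call, every guess is accepted, yet token 14, which the model predicted after the last accepted guess, is dropped.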

Modified code:

```python
# GUESS_SIZE + 1 slots: on a full match, hits also keeps the last element of correct
hits = [first_guess] + [0] * GUESS_SIZE
#multi-level window is filled
#match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        # count matched positions; unlike for/range, gg reaches len(myguess) when every guess matches
        gg = 0
        while gg < len(myguess):
            if myguess[gg] != correct[gg]:
                break
            gg += 1
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
#max_hit is the length of longest accepted sequence in verification branch
```
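Below is a standalone sketch of the proposed matching logic on plain lists. It makes the following assumptions: `guess_results` stands in for the `torch.argmax` output, and `max_hit` and `max_hit_idx` start at 0. The names `match_fixed` and `matched_prefix`, `GUESS_SIZE = 3`, and the token IDs are hypothetical. Instead of the `while` loop, it counts the matched prefix with `next()` over `zip`. This gives the same count whenever `correct` is at least as long as `myguess`, and here `correct` is always one element longer.

```python
GUESS_SIZE = 3  # hypothetical window size for this toy example


def matched_prefix(myguess, correct):
    """Number of leading positions where myguess agrees with correct."""
    return next(
        (i for i, (g, c) in enumerate(zip(myguess, correct)) if g != c),
        len(myguess),  # no mismatch: every guess matched
    )


def match_fixed(first_guess, guess_results, guess_tokens, max_hit=0):
    """Standalone sketch of the proposed matching logic on plain lists."""
    max_hit_idx = 0
    hits = [first_guess] + [0] * GUESS_SIZE  # room for GUESS_SIZE + 1 tokens
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = matched_prefix(myguess, correct)  # equals GUESS_SIZE on a full match
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
    return max_hit, max_hit_idx, hits


# Window 0 fails at its last position; window 1 matches completely.
results = [12, 13, 55, 12, 13, 14]
guesses = [11, 12, 99, 11, 12, 13]
print(match_fixed(11, results, guesses))  # (3, 1, [11, 12, 13, 14])
```

The second window accepts all three guesses plus token 14, so `max_hit` becomes 3 and `hits` holds `GUESS_SIZE + 1` tokens. With the original `for` loop, the same inputs would give `(2, 0, [11, 12, 13])`. The fully matched second window also scores 2, which does not beat the first window's 2, so it is never selected.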