flexflow / FlexFlow

FlexFlow Serve: Low-Latency, High-Performance LLM Serving
https://flexflow.readthedocs.io
Apache License 2.0

Using Algorithm 2 in SpecInfer paper, I get wrong outputs. #1302

Open dutsc opened 8 months ago

dutsc commented 8 months ago

I'm having trouble understanding the SpecInfer source code.

The SpecInfer paper gives the following pseudocode for how the verify model checks the draft model's output: [image]

I implemented this pseudocode in Python, but the output looks abnormal and does not seem to be correct.

My prompt: Could you please give me some advice on programming?
My generated text: Without knowing what you want to develop I cannot help you very much. You have to study the specific topics you want to learn. I assume you want to learn programming thinking about programs. I'm very vague about that since I work more in scientific computing than in programming.
no probsdpocket weed and we got stoned now i'm tripping balls high as hell now that's WHO I

I set max_length=100 and a draft-model lookahead of 4 inference steps; the verify model is opt-6.7b and the draft model is opt-125m.

I hoped to solve the problem by reading the SpecInfer source code, but it is very difficult for me. My guess is that the verify model checks the draft model's output in the prepare_next_batch_init and traverse_verify_tree functions in request_manager.cc, but I can't really follow that code.

Here is my Python implementation of the pseudocode above:

import random

import torch


@torch.no_grad()  # a bare `torch.no_grad()` statement has no effect; use it as a decorator or `with` block
def verify_stochastic(root, llm_logits, temperature=0):
    # NOTE: despite their names, llm_logits and token_logits_pair.logits are treated
    # as probabilities (softmax already applied): the acceptance ratio and
    # torch.multinomial below are only meaningful for probabilities.
    # `temperature` is currently unused.
    V = []  # List to store verified tokens
    u = root  # Start with the root node
    # index_mapping stores the output results of SSMs and merges them into a tree;
    # llm_logits[:, index_mapping[node] + 1, :] is P(x | node, LLM)
    index_mapping = {}
    index_mapping[root] = -1
    create_index_mapping(root, index_mapping)  # user-defined helper (not shown)
    # while u is a non-leaf node do
    while len(u.children) > 0:

        # H = child(u)
        H = list(u.children.values())

        # while H is not empty do
        while len(H) > 0:

            # s ∼ rand(H): select a child uniformly at random
            s = random.choice(H)

            # r ∼ U(0, 1): uniform float in [0, 1)
            r = random.random()

            # x_s: token of the node
            x_s = s.token_logits_pair.token

            # if r ≤ P(x_s | u, LLM) / P(x_s | u, SSMs) then
            ssmp_s = s.token_logits_pair.logits[:, x_s].item() + 1e-9
            llmp_s = llm_logits[:, index_mapping[u] + 1, x_s].item() + 1e-9
            print(f"ssmp: {ssmp_s}")
            print(f"llmp: {llmp_s}")
            if r <= llmp_s / ssmp_s:
                V.append(x_s)
                u = s
                break
            else:
                # P(x | u, LLM) := norm(max(0, P(x | u, LLM) − P(x | u, SSMs)))
                llmp = llm_logits[:, index_mapping[u] + 1, :]
                ssmp = s.token_logits_pair.logits
                new_dist = torch.clamp(llmp - ssmp, min=0)
                new_dist = new_dist / new_dist.sum(dim=-1, keepdim=True)
                llm_logits[:, index_mapping[u] + 1, :] = new_dist

                H.remove(s)
        # if H is empty then: every child of u was rejected
        if len(H) == 0:
            break
    # x_next ∼ P(x | u, LLM). Index by u, not s.parent: after an accepted leaf,
    # s.parent is u's parent, and s is undefined if root has no children.
    llmp = llm_logits[:, index_mapping[u] + 1, :]
    x_next = torch.multinomial(llmp, num_samples=1)  # sample from the (possibly residual) LLM distribution

    V.append(x_next)
    return V
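A useful sanity check for the accept/reject rule is that, at a single node, it should reproduce the LLM's distribution exactly. The sketch below is a torch-free restatement of the inner loop above under two assumptions that are not from the original code: both distributions are already probabilities (softmax applied), and the draft children of the node were sampled i.i.d. from the SSM distribution. `residual`, `sample`, and `verify_one_step` are hypothetical helper names.

```python
import random


def residual(p, q):
    """norm(max(0, p - q)): the LLM distribution used after a rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(r)
    # total == 0 means p <= q everywhere, so no draft could have been rejected
    # except through rounding; fall back to p in that case.
    return [x / total for x in r] if total > 0 else list(p)


def sample(dist, rng):
    """Draw an index from a discrete distribution given as a list of floats."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


def verify_one_step(p_llm, q_ssm, drafts, rng):
    """Accept/reject loop at one node u (hypothetical helper).

    p_llm:  P(x | u, LLM) as probabilities
    q_ssm:  P(x | u, SSM), the distribution every draft child was sampled from
    drafts: draft tokens (children of u), assumed i.i.d. samples from q_ssm
    Returns (token, accepted) where accepted is False for the fallback sample.
    """
    p = list(p_llm)
    H = list(drafts)
    while H:
        x = H.pop(rng.randrange(len(H)))      # s ~ rand(H)
        if rng.random() <= p[x] / q_ssm[x]:    # r <= P(x | u, LLM) / P(x | u, SSM)
            return x, True
        p = residual(p, q_ssm)                 # P(x | u, LLM) := norm(max(0, P_LLM - P_SSM))
    return sample(p, rng), False               # x_next ~ residual P(x | u, LLM)


# Monte Carlo check: the returned token should follow p_llm
rng = random.Random(0)
p_llm = [0.5, 0.3, 0.2]
q_ssm = [0.2, 0.2, 0.6]
n = 100_000
counts = [0, 0, 0]
for _ in range(n):
    drafts = [sample(q_ssm, rng) for _ in range(3)]   # three i.i.d. draft children
    token, _ = verify_one_step(p_llm, q_ssm, drafts, rng)
    counts[token] += 1
freqs = [c / n for c in counts]
print([round(f, 3) for f in freqs])   # expected to be close to p_llm
```

If a per-node implementation fails a check like this, the problem is in the accept/reject step itself rather than in the tree bookkeeping.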

Here are the TokenLogitsPair and TreeNode class definitions:

class TokenLogitsPair:
    def __init__(self, token, logits):
        self.token = token
        self.logits = logits
    def to(self, device, non_blocking=False):
        self.token = self.token.to(device, non_blocking=non_blocking)
        self.logits = self.logits.to(device, non_blocking=non_blocking)
        return self

class TreeNode:
    def __init__(self, token_logits_pair=None, parent=None):
        self.token_logits_pair = token_logits_pair
        self.parent = parent
        self.children = {}
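To isolate the tree bookkeeping from the per-node rule, here is a small torch-free walk over a toy tree. It is a sketch, not FlexFlow's code: the `ToyNode` class and its `p_llm`/`q_ssm` fields are hypothetical stand-ins for `TreeNode` plus the rows of `llm_logits` that `index_mapping` points to, and all distributions are assumed to be probabilities. The detail it highlights is that the final token x_next is drawn from the LLM distribution at the last node u, which after a fully accepted path is a leaf.

```python
import random
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyNode:
    """Hypothetical toy tree node with its distributions stored on it."""
    token: int
    p_llm: List[float]                    # P(x | this node, LLM)
    q_ssm: Optional[List[float]] = None   # P(x | this node, SSM); needed only if it has children
    children: List["ToyNode"] = field(default_factory=list)


def residual(p, q):
    """norm(max(0, p - q)); falls back to p if nothing is left."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(r)
    return [x / total for x in r] if total > 0 else list(p)


def verify_tree_stochastic(root, rng):
    verified = []
    u = root
    p = list(u.p_llm)                     # current (possibly residual) LLM distribution at u
    while u.children:
        H = list(u.children)
        accepted = None
        while H:
            s = H.pop(rng.randrange(len(H)))                  # s ~ rand(H)
            if rng.random() <= p[s.token] / u.q_ssm[s.token]:
                accepted = s
                break
            p = residual(p, u.q_ssm)                          # reject: update P(x | u, LLM)
        if accepted is None:
            break                                             # every child of u was rejected
        verified.append(accepted.token)
        u = accepted
        p = list(u.p_llm)                                     # fresh distribution at the new u
    # x_next ~ P(x | u, LLM), taken at the final u (a leaf if every step was accepted)
    verified.append(rng.choices(range(len(p)), weights=p, k=1)[0])
    return verified


# Draft child is always accepted; the bonus token must come from the leaf.
accept_tree = ToyNode(0, [0.0, 1.0, 0.0], [0.0, 1.0, 0.0],
                      [ToyNode(1, [0.0, 0.0, 1.0])])
# Both draft children are rejected; the residual at the root forces token 0.
reject_tree = ToyNode(0, [1.0, 0.0, 0.0], [0.0, 0.5, 0.5],
                      [ToyNode(1, [1 / 3] * 3), ToyNode(2, [1 / 3] * 3)])
print(verify_tree_stochastic(accept_tree, random.Random(0)))  # [1, 2]
print(verify_tree_stochastic(reject_tree, random.Random(0)))  # [0]
```

In the first toy tree, drawing the bonus token from the parent's distribution instead of the leaf's would give `[1, 1]`, repeating the accepted token's position rather than extending the sequence.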

I hope someone can help me.

dutsc commented 8 months ago

I checked the traverse_verify_tree() function in /FlexFlow/src/runtime/request_manager.cc and found that it only checks whether tokens are equal. Does this mean that SpecInfer's default implementation is the VERIFYGREEDY function from Algorithm 2 in the paper?

Algorithm 2 in the paper: [image]

This pseudocode walks down the token tree from the root, at each step following the child whose token matches the verify model's output, and stops when no child matches; the result is the path whose tokens agree with the verify model.
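VERIFYGREEDY reduces to a deterministic walk: follow the child whose token equals the verify model's greedy prediction, stop when none does, and also keep the prediction at the stopping node, so every step yields at least one token. A minimal sketch, assuming each node already carries the verify model's argmax token; the `GNode` class and its `llm_argmax` field are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GNode:
    """Hypothetical tree node carrying the verify model's greedy next token."""
    token: int
    llm_argmax: int
    children: List["GNode"] = field(default_factory=list)


def verify_greedy(root):
    verified = []
    u = root
    while u.children:
        # child whose draft token equals the verify model's prediction at u, if any
        match = next((s for s in u.children if s.token == u.llm_argmax), None)
        if match is None:
            break
        verified.append(match.token)
        u = match
    verified.append(u.llm_argmax)   # the verify model's own token at the stopping node
    return verified


tree = GNode(0, 5, [GNode(3, 4), GNode(5, 9, [GNode(7, 8)])])
print(verify_greedy(tree))  # [5, 9]
```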

traverse_verify_tree() code snippet:

for (int i = 0; i < outputSerializedTree.size(); i++) {
    auto input = inputSerializedTree.at(i);
    auto output = outputSerializedTree.at(i);

    if (i == 0) {
      verifiedTree.push_back(output); 
      new_committed_tokens.push_back(std::make_pair(
          input.second,
          committed_tokens.at(guid).at(i).second)); // <input_abs_depth,
                                                    // input_index_in_batch>
      // std::cout << committed_tokens.at(guid).at(i).first << ", "
      //           << committed_tokens.at(guid).at(i).second << std::endl;
      // std::cout << input.first << ", " << input.second << std::endl;

      assert(committed_tokens.at(guid).at(i).first == input.second);
      continue;
    }

    if (input.first == verifiedTree.back().first &&
        input.second == verifiedTree.back().second) {  //  input == verifiedTree.back()
      verifiedTree.push_back(output);
      new_committed_tokens.push_back(std::make_pair(
          input.second,
          committed_tokens.at(guid).at(i).second)); // <input_abs_depth,
                                                    // input_index_in_batch>
      assert(committed_tokens.at(guid).at(i).first == input.second);
    }
  }

traverse_verify_tree() is only about 100 lines in total; apart from the code above, it mostly prints logs.
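Stripped of logging, asserts, and the committed_tokens bookkeeping, the loop's matching logic fits in a few lines of Python, which makes the equality-only check easy to see. This is a hedged re-expression, not FlexFlow code: it assumes each serialized entry is a (token, depth) pair, consistent with the `input_abs_depth` comments, and that each prediction carries the depth of the position it fills (input depth + 1). `traverse_verify` is a hypothetical name.

```python
def traverse_verify(input_tree, output_tree):
    """Hypothetical Python mirror of the matching loop in traverse_verify_tree().

    input_tree[i]:  (token, depth) of the i-th serialized draft token
    output_tree[i]: (token, depth) predicted by the verify model at position i
    """
    verified = []
    for i, (inp, out) in enumerate(zip(input_tree, output_tree)):
        if i == 0:
            verified.append(out)      # the root's prediction is always kept
            continue
        if inp == verified[-1]:       # draft token equals the last verified prediction
            verified.append(out)
    return verified


# Toy tree in depth-first order: 10 -> {11 -> 12, 13}
inputs = [(10, 0), (11, 1), (12, 2), (13, 1)]
outputs = [(11, 1), (14, 2), (15, 3), (16, 2)]
print(traverse_verify(inputs, outputs))   # [(11, 1), (14, 2)]
```

Here draft token 11 is accepted because it equals the prediction at the root, 12 and 13 are not, and (14, 2) is the verify model's own token at the last accepted position.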

jiazhihao commented 8 months ago

The current implementation performs greedy decoding, and we are working on a PR for multi-step stochastic sampling and verification. Were the incorrect outputs generated with greedy or stochastic decoding?

dutsc commented 8 months ago

The incorrect outputs were generated with stochastic verification, following Algorithm 2 in the SpecInfer paper. When I use greedy verification from Algorithm 2 instead, my implementation produces the same text as SpecInfer for the same prompt.

My prompt: please introduce Kobe Bryant, who played basketball in NBA.

SpecInfer outputs:

I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or

My implementation's greedy-verify outputs:

I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or