state-spaces / mamba

Mamba SSM architecture

Speculative Decoding with Mamba 1 #391

Closed adityakotha03 closed 4 months ago

adityakotha03 commented 4 months ago

Hi, I am trying to implement speculative decoding from the paper "Accelerating Large Language Model Decoding with Speculative Sampling"; below is the code snippet:

from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch

# Load tokenizer and model
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)

# Prepare the initial input
input_text = "Mamba is a type of"
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(device)

# Generate the output and move it back to the CPU for decoding
out = model.generate(input_ids, max_length=input_ids.shape[-1]+10, cg=True)
print(out)
print(tokenizer.batch_decode(out.cpu()))

from mamba_ssm.utils.generation import InferenceParams

def get_distribution(logits, temperature):
    # Temperature-scaled softmax over the vocabulary
    probs = torch.softmax(logits / (temperature + 1e-10), dim=-1)
    return probs

def sample(logits, temperature):
    # Draw a single token id from the temperature-scaled distribution
    probs = get_distribution(logits, temperature)
    return torch.multinomial(probs, num_samples=1)[0]

@torch.inference_mode()
def speculative_sampling(target_model, draft_model, input_ids, target_len, lookahead=4, temperature=1.0, debug=False, repetition_penalty=1.0, top_k=1, top_p=0.0, min_p=0.0):
    assert input_ids.shape[0] == 1, 'Batch size should be 1'

    n = input_ids.shape[-1]
    fin_prompt_ids = input_ids.detach().clone()

    while n < target_len:
        n_orig = n
        N = fin_prompt_ids.shape[-1]
        # Let the draft model propose `lookahead` tokens and keep its per-step logits
        outs = draft_model.generate(
            input_ids=fin_prompt_ids, max_length=(N + lookahead), cg=True,
            output_scores=True, return_dict_in_generate=True,
            repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, min_p=min_p,
        )
        draft_outputs, draft_logits = outs.sequences, outs.scores

        # Stack the per-step draft logits into shape (batch, lookahead, vocab) and move them to the model device
        draft_logits = torch.cat([logits.unsqueeze(1) for logits in draft_logits], dim=1).to(device)

        if debug:
            continuation = tokenizer.decode(draft_outputs[0, n_orig:], skip_special_tokens=True)
            print(f"Possible continuations: {continuation}")
            print(f"Possible continuations (token ids): {draft_outputs[0, n_orig:]}")

        # Score the full draft sequence with the target model in a single forward pass
        infer_params = InferenceParams(max_batch_size=1, max_seqlen=draft_outputs.shape[1])
        y2 = target_model(draft_outputs.to(device), inference_params=infer_params)
        target_logits = y2.logits

        target_model_distribution = get_distribution(target_logits[:, -lookahead:], temperature)
        draft_model_distribution = get_distribution(draft_logits, temperature)

        accepted_flag = 1

        for t in range(lookahead):
            # Acceptance test: keep the draft token with probability min(1, p_target / p_draft)
            numerator = target_model_distribution[:, t, draft_outputs[0, N + t]]
            denominator = draft_model_distribution[:, t, draft_outputs[0, N + t]]
            ratio = numerator / denominator
            uniform_distribution = torch.rand_like(numerator, device=device)

            if (uniform_distribution < torch.min(torch.ones_like(numerator), ratio)).any():
                fin_prompt_ids = torch.cat([fin_prompt_ids, draft_outputs[:, N + t].unsqueeze(dim=-1)], dim=-1)
                n += 1
            else:
                # Rejected: resample from the normalized residual distribution max(0, p_target - p_draft)
                new_dist = (target_model_distribution[:, t, :] - draft_model_distribution[:, t, :])
                new_dist = torch.max(torch.zeros_like(new_dist), new_dist)
                new_dist = new_dist / new_dist.sum(dim=-1, keepdim=True)
                token_id = torch.multinomial(new_dist, num_samples=1)[0]
                fin_prompt_ids = torch.cat([fin_prompt_ids, token_id[None, ...]], dim=-1)
                accepted_flag = 0
                break

        if accepted_flag == 1:
            # All draft tokens were accepted: sample one extra token from the target model
            sample_token = sample(target_logits[:, -1, :], temperature=temperature)
            fin_prompt_ids = torch.cat([fin_prompt_ids, sample_token[None, ...]], dim=-1)

        if debug:
            final_continuation = tokenizer.decode(fin_prompt_ids[0, n_orig:], skip_special_tokens=True)
            print(f"Accepted continuations: {final_continuation}")
            print(f"Accepted continuations (token ids): {fin_prompt_ids[0, n_orig:]}")

        n += 1

    return fin_prompt_ids

# Example usage
target_len = input_ids.shape[-1] + 10
final_output = speculative_sampling(model, model, input_ids, target_len, temperature=1.0, debug=False)
print(final_output)
print(tokenizer.batch_decode(final_output.cpu()))

I'm using the same model as both the target and the draft for testing purposes, but I'm getting different outputs. When I tried debugging, I found that the logits from the forward pass with inference_params differed from the ones returned by generate. Any insights into what might be causing this would be appreciated. I am also attaching a Google Colab.
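
For reference, here is a rough sketch of the kind of check described above (it reuses the model, tokenizer, input_ids, and InferenceParams objects from the snippet; fp16 step-by-step decoding and the full parallel forward pass won't match bit-for-bit, so small differences are expected). It prints, for each generated token, the gap between the generate scores and the forward-pass logits at the same index and at the previous index:

prompt_len = input_ids.shape[-1]
gen = model.generate(input_ids, max_length=prompt_len + 10, cg=True,
                     output_scores=True, return_dict_in_generate=True)
# One logits tensor per generated token, stacked to (1, new_tokens, vocab)
step_logits = torch.stack(gen.scores, dim=1)

# Re-score the final sequence with a single forward pass
params = InferenceParams(max_batch_size=1, max_seqlen=gen.sequences.shape[1])
full_logits = model(gen.sequences, inference_params=params).logits

for i in range(step_logits.shape[1]):
    same_idx = (step_logits[0, i] - full_logits[0, prompt_len + i]).abs().max().item()
    prev_idx = (step_logits[0, i] - full_logits[0, prompt_len + i - 1]).abs().max().item()
    print(f"step {i}: max diff at same index {same_idx:.4f}, at previous index {prev_idx:.4f}")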

adityakotha03 commented 4 months ago

It works after changing the target slice to get_distribution(target_logits[:, -lookahead-1:-1], temperature).
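
For context, the issue is an off-by-one: target_logits comes from a causal forward pass over the full draft sequence of length N + lookahead, and the logits at index i predict token i + 1, so the logits that score the lookahead draft tokens are shifted back by one position rather than being the last lookahead entries. A minimal sketch of the corrected alignment, using the variable names from the snippet above:

# target_logits: (1, N + lookahead, vocab); index i predicts token i + 1, so the
# logits scoring the draft tokens at positions N .. N + lookahead - 1 live at
# indices N - 1 .. N + lookahead - 2, i.e. the slice below.
target_model_distribution = get_distribution(target_logits[:, -lookahead - 1:-1], temperature)
draft_model_distribution = get_distribution(draft_logits, temperature)

# The last position (index -1) predicts the token after the final draft token,
# so sampling the bonus token from target_logits[:, -1, :] when every draft
# token is accepted remains correct.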