Hi, I am trying to implement speculative decoding from the paper "Accelerating Large Language Model Decoding with Speculative Sampling", and below is the code snippet:

from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
# Load tokenizer and model (all state-spaces Mamba models use the GPT-NeoX tokenizer,
# so the 2.8b-hf tokenizer works for mamba-130m as well)
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)
# Prepare initial input
input_text = "Mamba is a type of"
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(device)
# Generate the output and move it back to the CPU for decoding
out = model.generate(input_ids, max_length=input_ids.shape[-1]+10, cg=True)
print(out)
print(tokenizer.batch_decode(out.cpu()))
from mamba_ssm.utils.generation import InferenceParams
def get_distribution(logits, temperature):
    # Temperature-scaled softmax over the vocabulary
    probs = torch.softmax(logits / (temperature + 1e-10), dim=-1)
    return probs

def sample(logits, temperature):
    # Draw a single token id from the temperature-scaled distribution
    probs = get_distribution(logits, temperature)
    return torch.multinomial(probs, num_samples=1)[0]
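# Speculative sampling (Chen et al., 2023): the draft model proposes `lookahead`
# tokens with distribution p, the target model scores them with distribution q,
# and each drafted token x is accepted with probability min(1, q(x) / p(x)).
# On the first rejection, a replacement token is drawn from the residual
# distribution norm(max(0, q - p)); if every drafted token is accepted, one
# extra token is sampled directly from the target model.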
@torch.inference_mode()
def speculative_sampling(target_model, draft_model, input_ids, target_len, lookahead=4, temperature=1.0, debug=False, repetition_penalty=1.0, top_k=1, top_p=0.0, min_p=0.0):
    assert input_ids.shape[0] == 1, 'Batch size should be 1'
    n = input_ids.shape[-1]
    fin_prompt_ids = input_ids.detach().clone()
    while n < target_len:
        n_orig = n
        N = fin_prompt_ids.shape[-1]
        # Draft model proposes `lookahead` tokens and returns the per-step scores
        outs = draft_model.generate(input_ids=fin_prompt_ids, max_length=(N + lookahead), cg=True, output_scores=True, return_dict_in_generate=True, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, min_p=min_p)
        draft_outputs, draft_logits = outs.sequences, outs.scores
        # Stack the per-step scores into [batch, lookahead, vocab] and ensure device compatibility
        draft_logits = torch.cat([logits.unsqueeze(1) for logits in draft_logits], dim=1).to(device)
        if debug:
            continuation = tokenizer.decode(draft_outputs[0, n_orig:], skip_special_tokens=True)
            print(f"Possible continuations: {continuation}")
            print(f"Possible continuations (token ids): {draft_outputs[0, n_orig:]}")
        # Score the drafted continuation with the target model in a single forward pass
        infer_params = InferenceParams(max_batch_size=1, max_seqlen=draft_outputs.shape[1])
        y2 = target_model(draft_outputs.to(device), inference_params=infer_params)
        target_logits = y2.logits
        # Temperature-scaled distributions: target over the last `lookahead` positions, draft over its steps
        target_model_distribution = get_distribution(target_logits[:, -lookahead:], temperature)
        draft_model_distribution = get_distribution(draft_logits, temperature)
        accepted_flag = 1
        for t in range(lookahead):
            # Probability each model assigns to the t-th drafted token
            numerator = target_model_distribution[:, t, draft_outputs[0, N + t]]
            denominator = draft_model_distribution[:, t, draft_outputs[0, N + t]]
            ratio = numerator / denominator
            uniform_distribution = torch.rand_like(numerator, device=device)
            # Rejection sampling: accept the drafted token with probability min(1, ratio)
            if (uniform_distribution < torch.min(torch.ones_like(numerator), ratio)).any():
                fin_prompt_ids = torch.cat([fin_prompt_ids, draft_outputs[:, N + t].unsqueeze(dim=-1)], dim=-1)
                n += 1
            else:
                # On rejection, resample from the residual distribution max(0, target - draft), renormalized
                new_dist = (target_model_distribution[:, t, :] - draft_model_distribution[:, t, :])
                new_dist = torch.max(torch.zeros_like(new_dist), new_dist)
                new_dist = new_dist / new_dist.sum(dim=-1, keepdim=True)
                token_id = torch.multinomial(new_dist, num_samples=1)[0]
                fin_prompt_ids = torch.cat([fin_prompt_ids, token_id[None, ...]], dim=-1)
                accepted_flag = 0
                break
        if accepted_flag == 1:
            # All drafted tokens were accepted: sample one bonus token from the target model
            sample_token = sample(target_logits[:, -1, :], temperature=temperature)
            fin_prompt_ids = torch.cat([fin_prompt_ids, sample_token[None, ...]], dim=-1)
        if debug:
            final_continuation = tokenizer.decode(fin_prompt_ids[0, n_orig:], skip_special_tokens=True)
            print(f"Accepted continuations: {final_continuation}")
            print(f"Accepted continuations (token ids): {fin_prompt_ids[0, n_orig:]}")
        n += 1
    return fin_prompt_ids
# Example usage: use the same checkpoint as both the target and the draft model for testing
target_len = input_ids.shape[-1] + 10
final_output = speculative_sampling(model, model, input_ids, target_len, temperature=1.0, debug=False)
print(final_output)
print(tokenizer.batch_decode(final_output.cpu()))
I'm using the same model for both the target and the draft for testing purposes, but I'm getting different outputs. While debugging, I found that the logits from the forward pass with infer_params differ from the ones produced during generation. Any insights into what might be causing this would be appreciated. I am also attaching a Google Colab notebook.
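To make the comparison concrete, here is a minimal sketch of the kind of check I mean (not the exact code from the Colab; the variable names are just for illustration). It reuses model, input_ids, and InferenceParams from the snippet above, and small numerical differences are expected in float16:

with torch.inference_mode():
    gen = model.generate(
        input_ids,
        max_length=input_ids.shape[-1] + 4,
        cg=True,
        output_scores=True,
        return_dict_in_generate=True,
    )
    # Forward pass over the prompt plus the drafted tokens, mirroring the
    # InferenceParams-based call inside speculative_sampling
    infer_params = InferenceParams(max_batch_size=1, max_seqlen=gen.sequences.shape[1])
    forward_logits = model(gen.sequences, inference_params=infer_params).logits
    # The scores for the first generated token come from the last prompt position,
    # i.e. index input_ids.shape[-1] - 1 of the forward pass
    step0_from_generate = gen.scores[0]
    step0_from_forward = forward_logits[:, input_ids.shape[-1] - 1]
    print(torch.allclose(step0_from_generate, step0_from_forward, atol=1e-2))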