Hi @kashif - thank you for your port!
Flash STU omits the autoregressive component from the original STU paper and relies solely on the spectral component. We found the autoregressive part to be sometimes helpful in terms of performance, but a bit too slow for our liking.
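Roughly speaking, the difference looks like this for a single STU layer. This is only a toy sketch, not the actual Flash STU code: the shapes and parameter names are made up for illustration, and the extra scaling terms and FFT-based convolutions of the real implementation are omitted.

```python
import torch

def stu_layer_sketch(u, phi, M_phi, M_u=None):
    """Toy illustration only -- shapes and names are invented for this sketch.
    u:     (batch, seq_len, d)   input sequence
    phi:   (seq_len, K)          fixed spectral filters (top Hankel eigenvectors)
    M_phi: (K, d, d)             learned projections for the spectral features
    M_u:   (k_u, d, d) or None   learned autoregressive projections over recent inputs
    """
    B, L, d = u.shape
    y = torch.zeros(B, L, d)

    # Spectral component: causally convolve the input with each fixed filter,
    # then apply a learned projection. (Real implementations use FFT-based
    # convolutions plus scaling terms omitted here.)
    for k in range(phi.shape[1]):
        filt = phi[:, k]
        for t in range(L):
            conv_t = sum(filt[s] * u[:, t - s] for s in range(t + 1))  # (B, d)
            y[:, t] += conv_t @ M_phi[k]

    # Autoregressive component: explicit dependence on the last few inputs
    # (the full STU also feeds back past outputs). This is the part Flash STU drops.
    if M_u is not None:
        for i in range(M_u.shape[0]):
            y[:, i:] += u[:, : L - i] @ M_u[i]

    return y
```

With the `M_u` (and output-feedback) terms dropped, the output depends on the input only through the fixed spectral filters, which is the cheaper path we kept.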
Here is an example of what an inference script for Flash STU could look like:
```python
import tiktoken
import torch
from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters
from safetensors import safe_open

tokenizer = tiktoken.get_encoding('o200k_base')
prompt = "Hi, my name is"
device = torch.device('cuda')

def generate_text(model, tokenizer, prompt, num_return_sequences=5, max_length=1024, device='cuda', temperature=1.0, top_k=50):
    model.eval()
    tokens = torch.tensor([tokenizer.encode(prompt, allowed_special={'<|endoftext|>'})], device=device)
    tokens = tokens.repeat(num_return_sequences, 1)
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(1337)
    eos_token_id = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
    with torch.no_grad():
        for _ in range(max_length - tokens.size(1)):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(tokens)
            logits = logits[:, -1, :]  # Get logits for the last token
            # Apply temperature scaling if temperature > 0
            if temperature > 0:
                logits = logits / temperature
            probs = torch.nn.functional.softmax(logits, dim=-1)  # Compute probabilities
            # Top-k sampling: draw the next token from the k most likely candidates
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(top_k_probs, 1, generator=sample_rng)
            next_token = torch.gather(top_k_indices, -1, ix)
            tokens = torch.cat((tokens, next_token), dim=1)  # The autoregressive part!
            # Break if any sequence generates the EOS token
            if (next_token == eos_token_id).any():
                break
    generated_sequences = []
    for i in range(num_return_sequences):
        decoded = tokenizer.decode(tokens[i].tolist())
        generated_sequences.append(decoded)
    return generated_sequences

config = FlashSTUConfig()
# seq_len, num_eigh, and use_hankel_L must match the configuration the checkpoint was trained with
phi = get_spectral_filters(
    seq_len,
    num_eigh,
    use_hankel_L,
    device,
    torch.bfloat16,
)
model = FlashSTU(config, phi)

# Load the checkpoint weights from safetensors
state_dict = {}
with safe_open('model.safetensors', framework="pt", device='cuda') as f:
    for k in f.keys():
        state_dict[k] = f.get_tensor(k)
model.load_state_dict(state_dict)
model.to(device)

generated_texts = generate_text(model, tokenizer, prompt, num_return_sequences=5, max_length=1024)
for i, text in enumerate(generated_texts):
    print(f"Sample {i + 1}: {text}\n")
```
This will generate until (1) the EOS token is produced or (2) the maximum sequence length is reached. Is this what you were looking for in terms of inference?
Thanks! Checking it out your way!
@windsornguyen as a slight aside, what are the key differences between the original STU implementation and the one in flash-STU, looking at just the STU module?
@kashif thanks!
Thanks for the repo! Here is my PyTorch port of the STU layer with an additional step function that might be useful, as asked in https://github.com/google-deepmind/spectral_ssm/issues/1