younesbelkada opened 1 month ago
A simple snippet to reproduce the current issue:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"
text = "Hello today we are going to"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok(text, return_tensors="pt").to(0)

# Forward pass in eval mode
with torch.no_grad():
    logits = torch.argmax(model(**inputs).logits, dim=-1)
print(tok.batch_decode(logits))

# Forward pass in train mode - the decoded tokens should match the ones above
model.train()
lm_logits = model(**inputs).logits
next_token = torch.argmax(lm_logits, dim=-1)
print(tok.batch_decode(next_token))

# Dummy loss to check the backward pass
loss = (1 - lm_logits).mean()
loss.backward()
```
Hi @tridao @albertfgu, I made an alternative PR in HF transformers: https://github.com/huggingface/transformers/pull/33195, where I simply copied the kernels over there. Let me know if you see any issue with potentially merging this PR into mamba-ssm - thanks!
Hi Albert Gu and Tri Dao,
First of all, thank you for this package. We would like to upstream some changes that were needed to train the FalconMamba-7B model using the mamba kernels.
This PR introduces a way to pass non-learnable RMS norm weights in order to normalize the B, C and dt states, as per our training procedure.
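For reference, a minimal sketch of the kind of RMS normalization involved, written so the caller supplies the (non-learnable) weight instead of the function creating one internally - this is an illustration of the idea, not the exact kernel code:

```python
import torch

def rms_norm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension by the root mean square, then scale by
    # `weight`. For FalconMamba the weight is a non-learnable tensor of ones,
    # passed in by the caller so nothing new is allocated on every call.
    input_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight * x).to(input_dtype)
```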
Another way could be to initialize `weight` in `rms_norm_forward` with `torch.ones_like`, but I'd prefer to force users to pass the non-learnable parameters themselves, to avoid allocating new tensors at each call of `mamba_inner_fn`. There might also be a way to call the RMS norm forward without having to pass RMS weights at all, but I am not sure. On the transformers side, we would call the interface with the following:
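Roughly, something like this from inside the mixer's `forward` (a sketch assuming the extended keyword arguments proposed here - `b_rms_weight`, `c_rms_weight`, `dt_rms_weight`, `b_c_dt_rms_eps` - where `self.b_c_rms` and `self.dt_rms` are illustrative names for the non-learnable `torch.ones` buffers):

```python
import torch
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

contextualized_states = mamba_inner_fn(
    projected_states,
    self.conv1d.weight,
    self.conv1d.bias,
    self.x_proj.weight,
    self.dt_proj.weight,
    self.out_proj.weight,
    self.out_proj.bias,
    -torch.exp(self.A_log.float()),
    None,  # B is computed from x_proj inside the fused op
    None,  # C as well
    self.D.float(),
    delta_bias=self.dt_proj.bias.float(),
    delta_softplus=True,
    # non-learnable RMS norm weights for B, C and dt
    b_rms_weight=self.b_c_rms,
    c_rms_weight=self.b_c_rms,
    dt_rms_weight=self.dt_rms,
    b_c_dt_rms_eps=self.rms_eps,
)
```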
Thank you very much in advance! @tridao @albertfgu