#33195 should resolve this; the original kernels in the mamba-ssm package miss some additional RMS normalization.
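For context, a rough sketch of the kind of extra normalization meant here (illustrative only; the function name, tensor shapes, and dummy inputs are assumptions, not the actual transformers or mamba-ssm code):

import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Parameter-free RMS normalization over the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# In the non-kernel ("slow") FalconMamba path an extra RMS normalization is applied to
# the SSM projections before the selective scan; a fused kernel that skips this step
# produces slightly different activations.
b, c, dt = (torch.randn(1, 16, 32) for _ in range(3))  # dummy projections
b, c, dt = rms_normalize(b), rms_normalize(c), rms_normalize(dt)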
Thanks for the link. I tried with that PR; the output changes/improves, but it is still not identical to the other three cases:
test_generate with use_cache=False and model.training=True
1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]
cc @younesbelkada
@jploski the output is not exactly the same but looks coherent compared to the previous results:
204 118.5 [ ]
627 118.0 [ new]
1842 118.0 [ type]
4307 117.5 [ gift]
241 117.0 [ a]
822 117.0 [ good]
914 117.0 [ great]
986 117.0 [ high]
1733 117.0 [ special]
829 116.5 [ very]
vs
1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]
I'd suspect there are some numerical differences between the kernel and the non-kernel path which add up and ultimately lead to not predicting exactly the same tokens.
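A minimal way to quantify that drift (a sketch only; it reuses the model and tokenizer from the script below, and the prompt is an arbitrary placeholder) is to compare the logits of the same forward pass in eval and train mode:

import torch

prompt_ids = tokenizer("The inland taipan is a", return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    model.eval()
    logits_eval = model(prompt_ids, use_cache=False).logits
    model.train()
    logits_train = model(prompt_ids, use_cache=False).logits

# A large maximum/mean absolute difference points at more than rounding noise.
diff = (logits_eval - logits_train).abs().float()
print(f"max diff={diff.max().item():.4f}, mean diff={diff.mean().item():.6f}")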
@younesbelkada A valid point, but perplexity over the WikiText test set seems significantly (~16%) worse for the train=True case. I evaluated it using the following script:
#!/usr/bin/env python3
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = './falcon-mamba-7b-instruct'
#model_id = './mamba-130m-hf'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
# unzip wikitext-2-raw-v1.zip
with open("./wiki.test.raw", encoding="utf-8") as f:
    wiki_text = f.read()

input_ids = tokenizer(wiki_text, return_tensors="pt").input_ids.to("cuda")

def test_ppl(train):
    model.train(train)
    # Adapted from https://huggingface.co/docs/transformers/en/perplexity
    stride = 512
    seq_len = input_ids.size(1)
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + stride, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        chunk_input_ids = input_ids[0, begin_loc:end_loc].unsqueeze(0).to('cuda')
        target_ids = chunk_input_ids.clone()
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(chunk_input_ids, labels=target_ids, use_cache=False)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it
            # internally shifts the labels to the left by 1.
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"train={train}, PPL={ppl.item()}")

test_ppl(False)
test_ppl(True)
With falcon-mamba-7b-instruct I'm getting PPL=7.3546 for train=False and PPL=8.5525 for train=True.
With mamba-130m-hf there is only a slight difference - PPL=27.1124 for train=False and PPL=27.1174 for train=True.
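(For reference: 8.5525 / 7.3546 ≈ 1.163, which is where the ~16% degradation quoted above comes from.)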
System Info
transformers version: 4.45.0.dev0

Who can help?
@ArthurZucker

Reproduction
Observed (unexpected) output:
Expected behavior
I would expect the logits produced by the forward pass to be the same regardless of whether model.train(True) has been invoked or whether use_cache is False or True. This holds true for mamba-130m-hf. However, when running the provided test script with falcon-mamba-7b-instruct, the output for the case model.train(True) with use_cache=False differs from the other outputs.
Practical relevance: this was discovered during ORPO training, where model.training is True and use_cache=False. The outputs and the loss calculated during training do not match the outputs observed after training in evaluation mode.
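A minimal sketch of the kind of consistency check implied here (the model path and prompt are placeholders; this is not the original test script):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "./falcon-mamba-7b-instruct"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

input_ids = tokenizer("The inland taipan is a", return_tensors="pt").input_ids.to("cuda")

# The last-position logits should agree across all four combinations of
# train/eval mode and use_cache=True/False.
for training in (False, True):
    for use_cache in (False, True):
        model.train(training)
        with torch.no_grad():
            logits = model(input_ids, use_cache=use_cache).logits[0, -1].float()
        top = torch.topk(logits, 10)
        print(f"train={training} use_cache={use_cache}:",
              [(int(i), round(v.item(), 1), tokenizer.decode(int(i))) for i, v in zip(top.indices, top.values)])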