huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Falcon Mamba] Unexpected model output with use_cache=False and model.train() #33234

Closed jploski closed 1 month ago

jploski commented 1 month ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = './falcon-mamba-7b-instruct'
#model_id = './mamba-130m-hf'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
model.tokenizer = tokenizer

input_ids = tokenizer("A falcon mamba is a", return_tensors="pt").input_ids.to("cuda")

def test_generate(use_cache, train):
    # Switch the model between training and evaluation mode before generating
    model.train(mode=train)

    print(f"test_generate with use_cache={use_cache} and model.training={train}")

    # Greedy-decode a single new token and keep the raw logits of that step
    outputs = model.generate(input_ids, do_sample=False, max_new_tokens=1, return_dict_in_generate=True, output_logits=True, use_cache=use_cache)

    # Print the top-10 candidates (token id, logit, decoded text) for the generated position
    logits = outputs['logits'][0]
    logits_sorted, indices = torch.sort(logits, descending=True)
    for i in range(10):
        tok = tokenizer.decode([ indices[-1][i].item() ])
        print(f"{indices[-1][i].item()} {logits_sorted[-1][i].item()} [{tok}]")

test_generate(True, False)
test_generate(False, False)
test_generate(True, True)
test_generate(False, True)

Observed (unexpected) output:

test_generate with use_cache=True and model.training=False
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=False and model.training=False
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=True and model.training=True
1842 137.0 [ type]
22277 136.0 [ snake]
27184 136.0 [ fictional]
45271 136.0 [ venom]
829 135.0 [ very]
2794 135.0 [ fast]
3766 135.0 [ highly]
6114 135.0 [ species]
12318 135.0 [ hybrid]
17846 135.0 [ deadly]
test_generate with use_cache=False and model.training=True
204 118.5 [ ]
627 118.0 [ new]
1842 118.0 [ type]
4307 117.5 [ gift]
241 117.0 [ a]
822 117.0 [ good]
914 117.0 [ great]
986 117.0 [ high]
1733 117.0 [ special]
829 116.5 [ very]

Expected behavior

I would expect the logits produced by the forward pass to be the same regardless of whether model.train(True) has been invoked or whether use_cache is False or True. This holds true for mamba-130m-hf. However, when running the provided test script with falcon-mamba-7b-instruct, the output for the case model.train(True) and use_cache=False differs from the other three outputs.

Practical relevance: this was discovered during ORPO training, where model.training is True and use_cache=False. As a result, the outputs and loss computed during training do not match the outputs observed after training in evaluation mode.
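
The same mismatch can also be checked without generate() by comparing raw forward-pass logits. A minimal sketch, reusing model and input_ids from the script above (the helper function is only for illustration):

import torch

def last_token_logits(use_cache, train):
    model.train(mode=train)
    with torch.no_grad():
        out = model(input_ids, use_cache=use_cache)
    # Logits of the final prompt position, upcast for a stable comparison
    return out.logits[0, -1].float()

reference = last_token_logits(use_cache=True, train=False)
for use_cache, train in [(False, False), (True, True), (False, True)]:
    diff = (last_token_logits(use_cache, train) - reference).abs().max().item()
    print(f"use_cache={use_cache} train={train} max |logit diff| vs eval+cache: {diff}")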

vasqu commented 1 month ago

#33195 should resolve this; the original kernels in the mamba-ssm package miss some additional RMS normalization.
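
For context, a rough sketch of the kind of normalization meant here: a weight-less RMS normalization applied to intermediate SSM tensors, which FalconMamba adds on top of plain Mamba. The shapes and eps below are illustrative placeholders, not the exact modeling code:

import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x**2) + eps) over the last dimension, with no learned weight
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)

b_proj = torch.randn(1, 10, 16)     # stand-in for an intermediate SSM tensor
print(rms_normalize(b_proj).shape)  # normalization preserves the shape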

jploski commented 1 month ago

#33195 should resolve this; the original kernels in the mamba-ssm package miss some additional RMS normalization.

Thanks for the link. I tried with that PR; the output changes/improves, but it is still not identical to the other three cases:

test_generate with use_cache=False and model.training=True
1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]

vasqu commented 1 month ago

cc @younesbelkada

younesbelkada commented 1 month ago

@jploski the output is not exactly the same, but it looks coherent compared to the previous results:

204 118.5 [ ]
627 118.0 [ new]
1842 118.0 [ type]
4307 117.5 [ gift]
241 117.0 [ a]
822 117.0 [ good]
914 117.0 [ great]
986 117.0 [ high]
1733 117.0 [ special]
829 116.5 [ very]

vs

1842 141.0 [ type]
22277 140.0 [ snake]
45271 140.0 [ venom]
1902 139.0 [ large]
2794 139.0 [ fast]
6114 139.0 [ species]
829 138.0 [ very]
916 138.0 [ long]
1385 138.0 [ small]
1462 138.0 [ non]

I'd suspect there are some numerical differences between the kernel and the non-kernel path that accumulate and ultimately lead to slightly different token predictions.
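
One way to probe that hypothesis would be to check whether the eval/train logit gap grows with context length, as accumulating numerical error would suggest. A rough sketch, reusing model and tokenizer from the scripts above; the prompt and lengths are arbitrary:

import torch

prompt = "A falcon mamba is a " * 64  # arbitrary text, just to get a longer context
long_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

for length in (16, 64, 256, long_ids.shape[1]):
    chunk = long_ids[:, :length]
    model.eval()
    with torch.no_grad():
        eval_logits = model(chunk, use_cache=False).logits[0, -1].float()
    model.train()
    with torch.no_grad():
        train_logits = model(chunk, use_cache=False).logits[0, -1].float()
    diff = (train_logits - eval_logits).abs().max().item()
    print(f"seq_len={length} max |logit diff| between eval and train: {diff}")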

jploski commented 1 month ago

I'd suspect there are some numerical differences between the kernel and the non-kernel path that accumulate and ultimately lead to slightly different token predictions.

@younesbelkada A valid point, but perplexity over WikiText seems significantly (~16%) worse for the train=True case. I evaluated it using the following script:

#!/usr/bin/env python3

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = './falcon-mamba-7b-instruct'
#model_id = './mamba-130m-hf'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
# unzip wikitext-2-raw-v1.zip

with open("./wiki.test.raw", encoding="utf-8") as f:
    wiki_text = f.read()

input_ids = tokenizer(wiki_text, return_tensors="pt").input_ids.to("cuda")

def test_ppl(train):
    model.train(train)

    # Adapted from https://huggingface.co/docs/transformers/en/perplexity
    stride = 512
    seq_len = input_ids.size(1)
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + stride, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        chunk_input_ids = input_ids[0, begin_loc:end_loc].unsqueeze(0).to('cuda')
        target_ids = chunk_input_ids.clone()
        target_ids[:,:-trg_len] = -100

        with torch.no_grad():
            outputs = model(chunk_input_ids, labels=target_ids, use_cache=False)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"train={train}, PPL={ppl.item()}")

test_ppl(False)
test_ppl(True)

With falcon-mamba-7b-instruct I'm getting PPL=7.3546 for train=False and PPL=8.5525 for train=True.

With mamba-130m-hf there is only a slight difference: PPL=27.1124 for train=False and PPL=27.1174 for train=True.
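
A possible follow-up sketch (reusing model and input_ids from the script above, with the same striding) is to compare per-chunk losses between the two modes, to see whether the gap is spread evenly or concentrated in a few chunks:

import torch

def chunk_losses(train, stride=512):
    model.train(train)
    losses = []
    seq_len = input_ids.size(1)
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + stride, seq_len)
        trg_len = end_loc - prev_end_loc
        chunk = input_ids[:, begin_loc:end_loc]
        targets = chunk.clone()
        targets[:, :-trg_len] = -100
        with torch.no_grad():
            losses.append(model(chunk, labels=targets, use_cache=False).loss.item())
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    return losses

eval_losses = chunk_losses(False)
train_losses = chunk_losses(True)
gaps = [t - e for e, t in zip(eval_losses, train_losses)]
worst = sorted(range(len(gaps)), key=lambda i: gaps[i], reverse=True)[:5]
for i in worst:
    print(f"chunk {i}: train-eval loss gap = {gaps[i]:.4f}")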