meta-math / MetaMath

MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models
https://meta-math.github.io
Apache License 2.0

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0 when setting max_new_tokens #7

Closed AegeanYan closed 1 year ago

AegeanYan commented 1 year ago

I found this bug with the following reproduction:

import random
import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig, GenerationConfig

# 4-bit NF4 quantization with bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# make the run reproducible
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.deterministic = True

device = "cuda:0"
tokenizer = LlamaTokenizer.from_pretrained("MetaMath-7B-V1.0", legacy=False)
model = LlamaForCausalLM.from_pretrained(
    "MetaMath-7B-V1.0",
    quantization_config=bnb_config,
    device_map="auto",
)
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

generation_config = GenerationConfig(
    temperature=0.8,
    max_new_tokens=512,  # here is the problem
    do_sample=True,
    top_p=0.95,
    early_stopping=True,
)

prompt = "Her eyes are beautiful."
tokens = tokenizer([prompt] * 10, return_tensors="pt", padding=True).to(device)
with torch.inference_mode():
    output = model.generate(
        **tokens, generation_config=generation_config, return_dict_in_generate=True
    )
decoded = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
print(decoded)

When I set max_new_tokens I get the tensor error; commenting that line out makes the run fine. Could you please check this? My transformers version is 4.33.3.
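For context on where the message comes from: with do_sample=True, generate() draws the next token with torch.multinomial, which raises exactly this error when the probability tensor contains inf/nan. With max_new_tokens=512 the loop also runs much longer than the default length, so it has more chances to hit a degenerate distribution. A minimal sketch to isolate whether the crash is sampling-specific (my own debugging suggestion appended to the script above, not a confirmed fix):

# Hypothetical debugging step, not a fix: greedy decoding never calls
# torch.multinomial, so if this runs to completion, the crash is
# specific to the sampling path.
greedy_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=False,
)
with torch.inference_mode():
    greedy_out = model.generate(
        **tokens, generation_config=greedy_config, return_dict_in_generate=True
    )
print(tokenizer.batch_decode(greedy_out.sequences, skip_special_tokens=True))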

AegeanYan commented 1 year ago

I think this is a batching problem; a single input does not trigger the error.
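A quick way to test that hypothesis (a sketch of my own, assuming the non-finite values already appear in the forward pass rather than only inside generate()) is to compare the logits of a single prompt against the batched version, reusing the objects from the repro script:

# Probe whether the batched 4-bit forward pass itself produces inf/nan
# logits, before generate() ever samples from them.
with torch.inference_mode():
    single = model(**tokenizer([prompt], return_tensors="pt").to(device)).logits
    batched = model(
        **tokenizer([prompt] * 10, return_tensors="pt", padding=True).to(device)
    ).logits
print("single  nan/inf:", torch.isnan(single).any().item(), torch.isinf(single).any().item())
print("batched nan/inf:", torch.isnan(batched).any().item(), torch.isinf(batched).any().item())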