I initially saw this error in some code I was running. I went back to the referenced Colab notebook and confirmed that the issue is also present there.
Below is the chunk of code related to the issue:
```python
from pydantic import BaseModel
from IPython.display import display, Markdown
from typing import List


def display_header(text):
    display(Markdown(f'{text}'))


def display_content(text):
    # Render the text inside a Markdown code block
    display(Markdown(f'```\n{text}\n```'))


DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
DEFAULT_MAX_NEW_TOKENS = 100


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'

display_header("Question:")
display_content(question_with_schema)

# `run` is defined in an earlier cell of the notebook (not shown here)
display_header("Answer, With json schema enforcing:")
result, enforced_scores = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_schema=AnswerFormat.schema())
display_content(result)

display_header("Answer, Without json schema enforcing:")
result, _ = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
display_content(result)

display_header("Answer, With json mode enforcing (json output, schemaless):")
result, _ = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_output=True)
display_content(result)
```
Running that cell produces the following traceback:

```
TypeError                                 Traceback (most recent call last)
 in <cell line: 29>()
     27
     28 display_header("Answer, With json schema enforcing:")
---> 29 result, enforced_scores = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_schema=AnswerFormat.schema())
     30 display_content(result)
     31

3 frames
/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1898         # 11. prepare logits warper
   1899         prepared_logits_warper = (
-> 1900             self._get_logits_warper(generation_config, device=input_ids.device)
   1901             if generation_config.do_sample
   1902             else None

TypeError: LogitsSaverManager.replace_logits_warper.<locals>.new_logits_warper() got an unexpected keyword argument 'device'
```
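The traceback suggests that the notebook's `LogitsSaverManager` replaces `_get_logits_warper` with a wrapper whose signature predates the `device` keyword that this transformers version passes at the call site. The sketch below reproduces that kind of mismatch under stated assumptions: `FakeModel`, `patch_fixed_signature` and `patch_forwarding` are hypothetical stand-ins, not the notebook's or transformers' actual code. It also shows one possible workaround, which is to have the replacement accept and forward `*args, **kwargs` so it no longer depends on the exact signature of the method it wraps.

```python
class FakeModel:
    """Hypothetical stand-in for a model whose newer `_get_logits_warper`
    accepts an extra `device` keyword argument."""

    def _get_logits_warper(self, generation_config, device=None):
        return ["warper", generation_config, device]

    def generate(self, generation_config, device):
        # Mirrors the call site in the traceback: `device` is passed by keyword.
        return self._get_logits_warper(generation_config, device=device)


def patch_fixed_signature(model):
    """Assumed old-style patch: the wrapper only accepts `generation_config`."""
    original = model._get_logits_warper  # bound method captured before patching

    def new_logits_warper(generation_config):
        return original(generation_config)

    # Instance attribute shadows the method, so `self._get_logits_warper(...)`
    # in `generate` now calls this plain function with the same arguments.
    model._get_logits_warper = new_logits_warper


def patch_forwarding(model, saved):
    """Signature-agnostic patch: forwards whatever arguments the caller passes."""
    original = model._get_logits_warper

    def new_logits_warper(*args, **kwargs):
        warpers = original(*args, **kwargs)
        saved.append(warpers)  # e.g. keep the warpers for later inspection
        return warpers

    model._get_logits_warper = new_logits_warper


if __name__ == "__main__":
    broken = FakeModel()
    patch_fixed_signature(broken)
    try:
        broken.generate("cfg", device="cpu")
    except TypeError as exc:
        print("fixed-signature patch:", exc)

    fixed = FakeModel()
    saved = []
    patch_forwarding(fixed, saved)
    print("forwarding patch:", fixed.generate("cfg", device="cpu"))
```

The forwarding wrapper trades strictness for robustness: it keeps working if the wrapped method gains new parameters, but it also silently accepts arguments it never inspects, so any logic inside it that relies on specific parameters should look them up in `args`/`kwargs` explicitly.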