uiuc-focal-lab / syncode

Efficient and general syntactical decoding for Large Language Models

GenerationMixin._get_logits_warper() missing 1 required positional argument: 'device' #84

Closed XZF0 closed 4 months ago

XZF0 commented 4 months ago

(Looks extremely interesting - really looking forward to trying it out :)

On an M1 Mac with no CUDA GPU, setting device='cpu':

Python 3.12.4
datasets                  2.20.0                   pypi_0    pypi
fire                      0.6.0                    pypi_0    pypi
interegular               0.3.3                    pypi_0    pypi
jsonschema                4.23.0             pyhd8ed1ab_0    conda-forge
jsonschema-specifications 2023.12.1          pyhd8ed1ab_0    conda-forge
jsonschema-with-format-nongpl 4.23.0               hd8ed1ab_0    conda-forge
python-fastjsonschema     2.20.0             pyhd8ed1ab_0    conda-forge
torch                     2.3.1                    pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.42.3                   pypi_0    pypi
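For reference, my setup was along these lines. This is a sketch paraphrased from memory: the model name is a placeholder and the constructor arguments follow the README quick-start, not my exact session.

```python
from syncode import Syncode  # as in the README quick-start

# Sketch of my setup; the model name is a placeholder and the
# constructor arguments are assumed from the README.
llm = Syncode(
    model="microsoft/phi-2",  # placeholder; any HF causal LM
    mode="grammar_mask",
    grammar="sql",
    device="cpu",             # no CUDA GPU on this machine
)

prompt = ("Give me the SQL query to select the name of the employee "
          "with the highest salary from the employee table. Given that "
          "the employee table has the following columns: name, salary.\n")
output = llm.infer(prompt)[0]  # raises the TypeError below
```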
TypeError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 prompt = "Give me the SQL query to select the name of the employee with the highest salary from the employee table. Given that the employee table has the following columns: name, salary.\n"
----> 2 output = llm.infer(prompt)[0]
      3 print(f"LLM output:\n{output}\n")

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/infer.py:144, in Syncode.infer(self, prompt, task_id, stop_words)
    142     output = FOLEval.run_eval(self, debug_task_id=task_id)
    143 elif self.dataset.type == "input":
--> 144     output = self.user_input(prompt, stop_words=stop_words)
    145 elif self.dataset.type == "json":
    146     output = JSONEval.run_json_eval(self, debug_task_id=task_id, eval_type = self.json_eval_type)

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/infer.py:180, in Syncode.user_input(self, prompt, stop_words)
    178         return self.model.generate_chat_completion_grammar(prompt)
    179     else:
--> 180         return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words)
    182 else:
    183     while True:

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/language_model.py:97, in HuggingFaceModel.generate_batch_completion_grammar(self, prompt, batch_size, stop_words)
     95 # Generate completions
     96 if (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
---> 97     generated_ids = self._generate(
     98         inputs, 
     99         gen_config, 
    100         gen_mode, 
    101         grammar_decoder=self.grammar_decoder,
    102         stop_criteria=stop_criteria
    103         )
    104 else:
    105     # Use generate from transformers library for other modes
    106     if stop_criteria is not None:

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/language_model.py:154, in HuggingFaceModel._generate(self, inputs, gen_config, gen_mode, grammar_decoder, stop_criteria)
    150 """
    151 We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
    152 """
    153 token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
--> 154 logit_warper = self.model._get_logits_warper(gen_config)
    155 max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
    157 while True:

TypeError: GenerationMixin._get_logits_warper() missing 1 required positional argument: 'device'
shubhamugare commented 4 months ago

Hi,

I haven't tried running on CPU since this new part of language_model.py was added. I think it should be easy to fix, but I'm a little busy for the next couple of days. Could you give it a try? I think we have to provide the device argument there explicitly. (If not, you can also use SynCode as a logits processor, as in the example here: https://github.com/uiuc-focal-lab/syncode/blob/main/notebooks/example_logits_processor.ipynb. That would avoid all of this and rely on the HuggingFace generation method.)
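For anyone hitting this in the meantime, the logits-processor route looks roughly like the sketch below. Treat it as an assumption-heavy sketch: the SyncodeLogitsProcessor name and its constructor arguments are taken from the linked notebook from memory, and the model name is a placeholder; double-check against the notebook.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from syncode import SyncodeLogitsProcessor  # name assumed from the linked notebook

model_name = "microsoft/phi-2"  # placeholder; any HF causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu")

# Constructor arguments are an assumption; see the notebook for the real signature.
syncode_processor = SyncodeLogitsProcessor(grammar="sql", tokenizer=tokenizer)

inputs = tokenizer("SELECT", return_tensors="pt")
# Routing generation through HF's generate() sidesteps SynCode's own
# sampling loop, and therefore the _get_logits_warper() call that crashes.
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    logits_processor=LogitsProcessorList([syncode_processor]),
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```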

XZF0 commented 4 months ago

Sure. This is a method in the HF transformers generation utils. It wants device as a string, which in that context is available as self.device.
Strange (and pretty hacky looking) that it is not pulled from the generation config. Perhaps there is a good reason?

Changing line 154 in language_model.py to logit_warper = self.model._get_logits_warper(gen_config, self.device) fixes it for me.

shubhamugare commented 4 months ago

Yes, it seems they changed the function argument device to be a required positional argument in this commit.

Can you create a short PR with this change? If we pass the argument explicitly, it will probably crash on older versions of the transformers library, so we will also need to update requirements.txt to require transformers >= the version that introduced this change.
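A backward-compatible alternative to a hard version pin would be to check the method's signature before calling it. This is only a sketch, assuming nothing beyond the fact that older transformers releases lack the device parameter:

```python
import inspect

def get_logits_warper_compat(model, gen_config):
    """Call _get_logits_warper with or without `device`, depending on the
    installed transformers version. Sketch only; pinning a minimum version
    in requirements.txt is the simpler alternative."""
    params = inspect.signature(model._get_logits_warper).parameters
    if "device" in params:
        # Newer transformers: device is a required positional argument.
        return model._get_logits_warper(gen_config, device=model.device)
    # Older transformers: the method takes only the generation config.
    return model._get_logits_warper(gen_config)
```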