uiuc-focal-lab / syncode

Efficient and general syntactical decoding for Large Language Models
MIT License

"Generate Indentation-Error-Free Python Code" example with "microsoft/phi-2" model generates error #78

Closed · border-b closed this 6 months ago

border-b commented 6 months ago

I tried to run the example for generating indentation-error-free Python code, only replacing the model with "microsoft/phi-2".

from syncode import Syncode

model_name = "microsoft/phi-2"

# Load the Syncode augmented model
syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python')
partial_code = "def is_prime(n):\n    '''Return if prime'''\n  "

# Generate a completion to the input partial code
constrained_output = partial_code + syn_llm.infer(partial_code)[0]
print(constrained_output)

This generates the following error:

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[10], line 6
      3 model_name = "microsoft/phi-2"
      5 # Load the Syncode augmented model
----> 6 syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python')
      7 partial_code = "def is_prime(n):\n    '''Return if prime'''\n  "
      9 #generate a completion to the input partial code

File /usr/local/lib/python3.10/site-packages/syncode/infer.py:108, in Syncode.__init__(self, model, mode, quantize, device, num_samples, grammar, dataset, num_few_shot, chat_mode, parse_output_only, dev_mode, log_level, new_mask_store, parser, task_id, json_eval_type, **kwargs)
    105 self.grammar_decoder = None
    107 if self.is_grammar_mode():
--> 108     self.grammar_decoder = SyncodeLogitsProcessor(
    109         self.grammar, 
    110         tokenizer=tokenizer, 
    111         logger=self.logger, 
    112         use_cache=(not self.new_mask_store), 
    113         parse_output_only=self.parse_output_only,
    114         num_samples=self.num_samples, 
    115         dev_mode=dev_mode,
    116         parser=parser,
    117         mode=mode,
    118         )
    120 # Set LLM generation args e.g. max_new_tokens, do_sample, etc.
    121 self.set_generation_args(kwargs, tokenizer)

File /usr/local/lib/python3.10/site-packages/syncode/grammar_decoder.py:52, in SyncodeLogitsProcessor.__init__(self, grammar, tokenizer, logger, use_cache, parse_output_only, num_samples, dev_mode, parser, mode)
     49 self.start_from = None         
     51 # Ignore whitespace tokens
---> 52 self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
     54 # Load dfa mask store
     55 self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
     56                             grammar=self.grammar, 
     57                             tokenizer=self.tokenizer, 
   (...)
     60                             mode=mode,
     61                             )

File /usr/local/lib/python3.10/site-packages/syncode/grammar_decoder.py:79, in SyncodeLogitsProcessor._get_ignore_whitespace(self, grammar)
     75 def _get_ignore_whitespace(self, grammar):
     76     """
     77     Check if the grammar allows whitespace tokens to be ignored.
     78     """
---> 79     base_parser = create_base_parser(grammar)
     80     terminals = base_parser.terminals
     81     ignore_terminals = base_parser.ignore_tokens

File /usr/local/lib/python3.10/site-packages/syncode/parsers/__init__.py:31, in create_base_parser(grammar, parser, indenter, cache_filename)
     30 def create_base_parser(grammar, parser='lalr', indenter=None, cache_filename=None):
---> 31     base_parser = Lark( # This is the standard Lark parser
     32                         grammar.ebnf,
     33                         parser=parser,
     34                         lexer="basic",
     35                         start="start",
     36                         postlex=indenter,
     37                         propagate_positions=True,
     38                         cache = cache_filename
     39                     )
     41     return base_parser

File /usr/local/lib/python3.10/site-packages/syncode/larkm/lark.py:362, in Lark.__init__(self, grammar, **options)
    358             self.options = old_options
    361     # Parse the grammar file and compose the grammars
--> 362     self.grammar, used_files = load_grammar(grammar, self.source_path, self.options.import_paths, self.options.keep_all_tokens)
    363 else:
    364     assert isinstance(grammar, Grammar)

File /usr/local/lib/python3.10/site-packages/syncode/larkm/load_grammar.py:1415, in load_grammar(grammar, source, import_paths, global_keep_all_tokens)
   1413 def load_grammar(grammar, source, import_paths, global_keep_all_tokens):
   1414     builder = GrammarBuilder(global_keep_all_tokens, import_paths)
-> 1415     builder.load_grammar(grammar, source)
   1416     return builder.build(), builder.used_files

File /usr/local/lib/python3.10/site-packages/syncode/larkm/load_grammar.py:1255, in GrammarBuilder.load_grammar(self, grammar_text, grammar_name, mangle)
   1252             imports[dotted_path] = base_path, aliases
   1254 for dotted_path, (base_path, aliases) in imports.items():
-> 1255     self.do_import(dotted_path, base_path, aliases, mangle)
   1257 for stmt in tree.children:
   1258     if stmt.data in ('term', 'rule'):

File /usr/local/lib/python3.10/site-packages/syncode/larkm/load_grammar.py:1338, in GrammarBuilder.do_import(self, dotted_path, base_path, aliases, base_mangle)
   1335         break
   1336 else:
   1337     # Search failed. Make Python throw a nice error.
-> 1338     open(grammar_path, encoding='utf8')
   1339     assert False, "Couldn't import grammar %s, but a corresponding file was found at a place where lark doesn't search for it" % (dotted_path,)

FileNotFoundError: [Errno 2] No such file or directory: 'common.lark'
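
For anyone debugging the same failure: the traceback suggests the lark grammar import could not resolve common.lark from the installed package. A quick way to check which .lark grammar files actually shipped with the installed syncode package (a generic filesystem check, not a Syncode API) is:

import os
import syncode

# List every .lark grammar file bundled with the installed syncode package.
pkg_dir = os.path.dirname(syncode.__file__)
for root, _, files in os.walk(pkg_dir):
    for name in files:
        if name.endswith(".lark"):
            print(os.path.join(root, name))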
shubhamugare commented 6 months ago

Hi @border-b ,

Someone reported a similar issue previously (#61), but it was closed without confirming what was going wrong. I'll try to figure out what's happening.

Which OS are you using? Can you try running this on Colab and see if it works there?

border-b commented 6 months ago

Hey @shubhamugare!

The issue persists in the Colab environment too: Link to Colab Notebook

I used Ubuntu 22.04 to run the script.

shubhamugare commented 6 months ago

Hi @border-b ,

It should be fixed now in the recent commit.

Also, note that the first run of Python generation may take ~10 minutes to compute the mask store. The mask store is cached after that, so all subsequent runs start immediately.
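
For reference, a minimal sketch (reusing the exact call from the example above) that makes the caching behavior visible by timing the setup on consecutive runs:

import time
from syncode import Syncode

# First run: the DFA mask store for the Python grammar is computed (can take ~10 minutes).
# Later runs: the cached mask store is loaded, so setup finishes almost immediately.
start = time.time()
syn_llm = Syncode(model="microsoft/phi-2", mode='grammar_strict', grammar='python')
print(f"Setup took {time.time() - start:.1f} seconds")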

Let me know if there are any other issues.

border-b commented 6 months ago

@shubhamugare This seems to be working fine with "microsoft/phi-2" and "mistralai/Mistral-7B-v0.1", but it doesn't work with "meta-llama/Meta-Llama-3-8B". It fails with the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 10
      7 partial_code = "def is_prime(n):\n    '''Return if prime'''\n  "
      9 #generate a completion to the input partial code
---> 10 constrained_output = partial_code+ syn_llm.infer(partial_code)[0]
     11 print(constrained_output)
     12 # def is_prime(n):
     13 #     '''Return if prime'''
     14 #     if n < 2:
   (...)
     18 #             return False
     19 #     return True

File /usr/local/lib/python3.10/site-packages/syncode/infer.py:150, in Syncode.infer(self, prompt, task_id, stop_words)
    148     output = FOLEval.run_eval(self, debug_task_id=task_id)
    149 elif self.dataset.type == "input":
--> 150     output = self.user_input(prompt, stop_words=stop_words)
    151 elif self.dataset.type == "json":
    152     output = JSONEval.run_json_eval(self, debug_task_id=task_id, eval_type = self.json_eval_type)

File /usr/local/lib/python3.10/site-packages/syncode/infer.py:186, in Syncode.user_input(self, prompt, stop_words)
    184         return self.model.generate_chat_completion_grammar(prompt)
    185     else:
--> 186         return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words)
    188 else:
    189     while True:

File /usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/syncode/language_model.py:97, in HuggingFaceModel.generate_batch_completion_grammar(self, prompt, batch_size, stop_words)
     95 # Generate completions
     96 if (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
---> 97     generated_ids = self._generate(
     98         inputs, 
     99         gen_config, 
    100         gen_mode, 
    101         grammar_decoder=self.grammar_decoder,
    102         stop_criteria=stop_criteria
    103         )
    104 else:
    105     # Use generate from transformers library for other modes
    106     if stop_criteria is not None:

File /usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/syncode/language_model.py:164, in HuggingFaceModel._generate(self, inputs, gen_config, gen_mode, grammar_decoder, stop_criteria)
    161     else:
    162         input_ids = token_ids
--> 164     outputs = self.model(
    165         input_ids, 
    166         attention_mask=attention_mask, 
    167         past_key_values=past_key_values
    168         )                
    169 except IndexError as e:  
    170     raise ValueError(f"The input length exceeds the context length of the model. {e}")

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1211, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1208 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1210 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1211 outputs = self.model(
   1212     input_ids=input_ids,
   1213     attention_mask=attention_mask,
   1214     position_ids=position_ids,
   1215     past_key_values=past_key_values,
   1216     inputs_embeds=inputs_embeds,
   1217     use_cache=use_cache,
   1218     output_attentions=output_attentions,
   1219     output_hidden_states=output_hidden_states,
   1220     return_dict=return_dict,
   1221     cache_position=cache_position,
   1222 )
   1224 hidden_states = outputs[0]
   1225 if self.config.pretraining_tp > 1:

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:992, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    989 if position_ids is None:
    990     position_ids = cache_position.unsqueeze(0)
--> 992 causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
    994 # embed positions
    995 hidden_states = inputs_embeds

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1095, in LlamaModel._update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens)
   1093 causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
   1094 if sequence_length != 1:
-> 1095     causal_mask = torch.triu(causal_mask, diagonal=1)
   1096 causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
   1097 causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)

RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

I think the same error occurs when running the JSON example with the Llama model.
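
For context (this is background on the PyTorch error, not anything Syncode-specific): older PyTorch builds have no CUDA BFloat16 kernel for torch.triu, which is exactly the call that fails inside LlamaModel._update_causal_mask in the traceback above. A minimal standalone reproduction, assuming such a PyTorch/CUDA build:

import torch

# On affected PyTorch versions this raises:
# RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
mask = torch.full((4, 4), float("-inf"), dtype=torch.bfloat16, device="cuda")
torch.triu(mask, diagonal=1)

Loading the model in float16 or float32 instead of bfloat16, or upgrading PyTorch, typically avoids it.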

shubhamugare commented 6 months ago

I think it could be because of the transformers version. More recent models may require the latest version of the HF transformers library.

border-b commented 6 months ago

print(transformers.__version__) gives me 4.40.2, which I believe is the latest version (link).

shubhamugare commented 6 months ago

Can you try setting quantize=False (it is True by default)? I guess they don't have bfloat16 support for Llama-3 yet.

border-b commented 6 months ago

Tried it. Still getting the same error.

border-b commented 6 months ago

from syncode import Syncode

model_name = "meta-llama/Meta-Llama-3-8B"

# Load the Syncode augmented model
syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python', quantize=False)
partial_code = "def is_prime(n):\n    '''Return if prime'''\n  "

# Generate a completion to the input partial code
constrained_output = partial_code + syn_llm.infer(partial_code)[0]
print(constrained_output)

exec(constrained_output)
# Correct Code :)

This is the code that gets the error.

shubhamugare commented 6 months ago

Actually, I see there is a bug in the quantize parameter right now: it is not being used while loading the model here. If possible, would you like to make this fix? Otherwise I can get it fixed later today.
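
For illustration only, a rough sketch of the kind of change such a fix could involve, assuming the model is loaded through transformers' AutoModelForCausalLM (the helper name and wiring are hypothetical, not the actual Syncode code):

import torch
from transformers import AutoModelForCausalLM

def load_hf_model(model_name: str, quantize: bool, device: str = "cuda"):
    # Hypothetical helper: honor the quantize flag when loading the model.
    # quantize=True loads bfloat16 weights; quantize=False keeps full precision.
    dtype = torch.bfloat16 if quantize else torch.float32
    return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)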

border-b commented 6 months ago

Sure, I can look into it.

border-b commented 6 months ago

@shubhamugare I created a PR with the fix. I've tested it locally with this code, and it seems to be working now.