Hi @border-b ,
A similar issue was reported previously in #61, but it was closed without confirming what was going wrong. I'll try to figure out what's happening.
Which OS are you using? Can you try running this on Colab and see if it works there?
Hey @shubhamugare!
The issue persists in the Colab environment too: Link to Colab Notebook
I used Ubuntu 22.04 to run the script.
Hi @border-b ,
It should be fixed now in the recent commit.
Also, note that the first run of Python generation may take about 10 minutes to compute the mask store. The mask store is then cached, so all future runs start immediately.
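For reference, here is a minimal sketch of that first run (the model name and arguments are just the ones used elsewhere in this thread):

from syncode import Syncode

# The first run with grammar='python' computes the mask store (about 10 minutes);
# the store is cached afterwards, so later runs start immediately.
syn_llm = Syncode(model="microsoft/phi-2", mode='grammar_strict', grammar='python')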
Let me know if there are any other issues.
@shubhamugare This seems to be working fine with "microsoft/phi-2" and "mistralai/Mistral-7B-v0.1", but it doesn't work with "meta-llama/Meta-Llama-3-8B". It fails with the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 10
7 partial_code = "def is_prime(n):\n '''Return if prime'''\n "
9 #generate a completion to the input partial code
---> 10 constrained_output = partial_code+ syn_llm.infer(partial_code)[0]
11 print(constrained_output)
12 # def is_prime(n):
13 # '''Return if prime'''
14 # if n < 2:
(...)
18 # return False
19 # return True
File /usr/local/lib/python3.10/site-packages/syncode/infer.py:150, in Syncode.infer(self, prompt, task_id, stop_words)
148 output = FOLEval.run_eval(self, debug_task_id=task_id)
149 elif self.dataset.type == "input":
--> 150 output = self.user_input(prompt, stop_words=stop_words)
151 elif self.dataset.type == "json":
152 output = JSONEval.run_json_eval(self, debug_task_id=task_id, eval_type = self.json_eval_type)
File /usr/local/lib/python3.10/site-packages/syncode/infer.py:186, in Syncode.user_input(self, prompt, stop_words)
184 return self.model.generate_chat_completion_grammar(prompt)
185 else:
--> 186 return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words)
188 else:
189 while True:
File /usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/syncode/language_model.py:97, in HuggingFaceModel.generate_batch_completion_grammar(self, prompt, batch_size, stop_words)
95 # Generate completions
96 if (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
---> 97 generated_ids = self._generate(
98 inputs,
99 gen_config,
100 gen_mode,
101 grammar_decoder=self.grammar_decoder,
102 stop_criteria=stop_criteria
103 )
104 else:
105 # Use generate from transformers library for other modes
106 if stop_criteria is not None:
File /usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/syncode/language_model.py:164, in HuggingFaceModel._generate(self, inputs, gen_config, gen_mode, grammar_decoder, stop_criteria)
161 else:
162 input_ids = token_ids
--> 164 outputs = self.model(
165 input_ids,
166 attention_mask=attention_mask,
167 past_key_values=past_key_values
168 )
169 except IndexError as e:
170 raise ValueError(f"The input length exceeds the context length of the model. {e}")
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1211, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1208 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1210 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1211 outputs = self.model(
1212 input_ids=input_ids,
1213 attention_mask=attention_mask,
1214 position_ids=position_ids,
1215 past_key_values=past_key_values,
1216 inputs_embeds=inputs_embeds,
1217 use_cache=use_cache,
1218 output_attentions=output_attentions,
1219 output_hidden_states=output_hidden_states,
1220 return_dict=return_dict,
1221 cache_position=cache_position,
1222 )
1224 hidden_states = outputs[0]
1225 if self.config.pretraining_tp > 1:
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:992, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
989 if position_ids is None:
990 position_ids = cache_position.unsqueeze(0)
--> 992 causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
994 # embed positions
995 hidden_states = inputs_embeds
File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1095, in LlamaModel._update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens)
1093 causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1094 if sequence_length != 1:
-> 1095 causal_mask = torch.triu(causal_mask, diagonal=1)
1096 causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1097 causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
I think the same error occurs when running the json example with the llama model.
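For what it's worth, the failing torch.triu call at the bottom of the traceback can be reproduced on its own with a small sketch like this (assuming a CUDA device and the same torch build as in the traceback):

import torch

# Mirrors what LlamaModel._update_causal_mask does: build a bfloat16 mask on the GPU
# and take its upper triangle, which this torch build does not implement on CUDA.
mask = torch.full((8, 8), torch.finfo(torch.bfloat16).min, dtype=torch.bfloat16, device="cuda")
torch.triu(mask, diagonal=1)  # RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'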
I think it could be because of the transformers version. More recent models may require the latest version of the HF transformers library.
print(transformers.__version__) gives me 4.40.2. I think this is probably the latest version (link).
Can you try setting quantize=False (it is set to True by default)? I guess they don't have bfloat16 support for Llama-3 yet.
Tried it. Still getting the same error.
from syncode import Syncode

model_name = "meta-llama/Meta-Llama-3-8B"

# Load the Syncode augmented model
syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python', quantize=False)

partial_code = "def is_prime(n):\n '''Return if prime'''\n "

# Generate a completion to the input partial code
constrained_output = partial_code + syn_llm.infer(partial_code)[0]
print(constrained_output)
exec(constrained_output)
# Correct Code :)
This is the code that produces the error.
Actually, I see there is a bug in the quantize parameter right now. Currently, it is not being used while loading the model here. If possible, would you like to make this fix? Otherwise I can get it fixed later today.
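In case it helps, this is a rough sketch of the kind of change I mean (illustrative only, not the actual SynCode loading code; the helper name is made up):

import torch
from transformers import AutoModelForCausalLM

def load_model(model_name: str, quantize: bool = True):
    # Pick the dtype based on the quantize flag instead of ignoring it,
    # so quantize=False loads the model in full precision.
    dtype = torch.bfloat16 if quantize else torch.float32
    return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)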
Sure, I can take a look into it.
I tried to run the example of generating indentation-error-free Python code, just replacing the model with "microsoft/phi-2". This produces the following error: