thistleknot opened 4 months ago
You can use patch_hf for transformers. For this usage, you can refer to the integration in chat.py: load the configuration as a dict and pass it to patch_hf along with your model.
from inf_llm.utils import patch_hf
config = load_yaml_config()['model']
model = patch_hf(model, config['type'], **config)
"load_yaml_config"?
I tried
import yaml
from inf_llm.utils import patch_hf
# Simulated YAML configuration as a string
config_string = """
model:
  type: mistral # Assuming 'mistral' is a valid type for your use case
  path: mistralai/Mistral-7B-Instruct-v0.2
  block_size: 128
  n_init: 128
  n_local: 4096
  topk: 16
  repr_topk: 4
  max_cached_block: 32
  exc_block_size: 512
  fattn: false
  base: 1000000
  distance_scale: 1.0
  max_len: 2147483647
  chunk_size: 8192
  conv_type: mistral-inst
"""
# Parsing the YAML string into a dictionary
config = yaml.safe_load(config_string)
# Assuming 'model' is a pre-initialized model object
model = None # Replace this with your actual model initialization logic
# Extracting type and the rest of the model configuration
model_type = config['model'].pop('type')
# Applying the configuration to the model
patched_model = patch_hf(model, model_type, **config['model'])
# Output for debugging
print("Model patched with the following configuration:")
print(config['model'])
and I get an error about only certain models being supported (mistral, llama, etc.):
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[21], line 35
32 model_type = config['model'].pop('type')
34 # Applying the configuration to the model
---> 35 patched_model = patch_hf(model, model_type, **config['model'])
37 # Output for debugging
38 print("Model patched with the following configuration:")
File /data/InfLLM/inf_llm/utils/patch.py:133, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
125 return tuple(v for v in [hidden_states, pkv, all_hidden_states, all_self_attns] if v is not None)
126 return BaseModelOutputWithPast(
127 last_hidden_state=hidden_states,
128 past_key_values=pkv,
129 hidden_states=all_hidden_states,
130 attentions=all_self_attns,
131 )
--> 133 forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))
135 if isinstance(model, LlamaForCausalLM):
136 Attention = LlamaAttention
KeyError: 'mistral'
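Looking at line 133 of patch.py in the traceback, the type value is used as a key into ATTN_FORWRAD, so it seems to select the attention implementation rather than the model family; the example configs in the repo appear to use inf-llm there. A minimal sketch of the change (inf-llm is my guess from those configs, and model obviously has to be a real loaded model rather than None):

# Same YAML as above, except for the attention type
config = yaml.safe_load(config_string)
config['model']['type'] = 'inf-llm'  # assumption: must be a key of ATTN_FORWRAD, not a model name
model_type = config['model'].pop('type')
patched_model = patch_hf(model, model_type, **config['model'])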
While I can patch the model, it won't work with the standard HF tools: transformers assumes past_key_values is subscriptable, but here past_key_values is a ContextManager.
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
generated_ids = patched_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
in MistralForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, **kwargs)
1206 max_cache_length = past_key_values.get_max_length()
1207 else:
-> 1208 cache_length = past_length = past_key_values[0][0].shape[2]
1209 max_cache_length = None
1211 # Keep only the unprocessed tokens:
1212 # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1213 # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1214 # input)
TypeError: 'ContextManager' object is not subscriptable
import yaml
from inf_llm.utils import patch_hf
from transformers import AutoModel

def load_yaml_config(file_path='path_to_your_config_file.yaml'):
    """Load a YAML configuration file."""
    with open(file_path, 'r') as file:
        return yaml.safe_load(file)

config_path = 'minicpm-inf-llm.yaml'
with open(config_path, 'r') as file:
    inf_llm_config = yaml.safe_load(file)
inf_llm_config

from inf_llm.utils import patch_hf
config = load_yaml_config(file_path=config_path)['model']
model = patch_hf(model, config['type'], **config)
ValueError                                Traceback (most recent call last)
Cell In[26], line 3
      1 from inf_llm.utils import patch_hf
      2 config = load_yaml_config(file_path=config_path)['model']
----> 3 model = patch_hf(model, config['type'], **config)

File /home/user/mamba/InfLLM/inf_llm/utils/patch.py:150, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
    148     Model = model.model.__class__
    149 else:
--> 150     raise ValueError("Only supports llama, mistral and qwen2 models.")
    152 hf_rope = model.model.layers[0].self_attn.rotary_emb
    153 base = base if base is not None else hf_rope.base
ValueError: Only supports llama, mistral and qwen2 models.
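From the first traceback, patch_hf checks isinstance(model, LlamaForCausalLM) (and presumably the Mistral/Qwen2 equivalents), so I'm guessing the model has to be loaded through one of those *ForCausalLM classes; AutoModel only gives the bare backbone, and a remote-code MiniCPM class wouldn't match either. Something along these lines is what I'd expect it to want (untested sketch; the config filename is a placeholder, and I'm reusing the Mistral checkpoint from above):

from transformers import AutoModelForCausalLM, AutoTokenizer
from inf_llm.utils import patch_hf

model_path = "mistralai/Mistral-7B-Instruct-v0.2"
# Load via a *ForCausalLM class so the isinstance check in patch_hf can match
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

config = load_yaml_config(file_path="mistral-inf-llm.yaml")["model"]
model = patch_hf(model, config["type"], **config)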
I use transformers with a custom script; I see you show how to use this with a custom FastChat script.
Do you have boilerplate code for wrapping a transformers pipeline to use with this?
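To be concrete, this is roughly the kind of wrapper I mean: a hand-rolled greedy loop in place of generate(), which just feeds whatever cache object the patched forward returns straight back in instead of indexing it. Untested, and it assumes the patched forward accepts its own past_key_values object.

import torch

def greedy_generate(model, tokenizer, prompt, max_new_tokens=100):
    # Manual greedy decoding: avoids prepare_inputs_for_generation,
    # which indexes past_key_values and breaks on the patched cache object.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    past_key_values = None
    generated = input_ids
    for _ in range(max_new_tokens):
        with torch.no_grad():
            out = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values  # opaque cache object; never indexed
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = next_token  # only feed the newly generated token on later steps
    return tokenizer.decode(generated[0], skip_special_tokens=True)

print(greedy_generate(patched_model, tokenizer, prompt))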