thunlp / InfLLM

The code of our paper "InfLLM: Unveiling the Intrinsic Capacity of LLMs for Understanding Extremely Long Sequences with Training-Free Memory"
MIT License

How to use with transformers? #30

Open thistleknot opened 4 months ago

thistleknot commented 4 months ago

I use transformers with a custom script. I see you show how to use this with a custom FastChat script.

Do you have boilerplate code showing how to wrap a transformers pipeline so it works with this?

guyan364 commented 4 months ago

You can use patch_hf for transformers; for this usage, refer to the integration in chat.py. Load the configuration as a dict and pass it to patch_hf along with your model.

from inf_llm.utils import patch_hf
config = load_yaml_config()['model']
model = patch_hf(model, config['type'], **config)
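load_yaml_config is not defined in the snippet above; it stands for whatever reads the model section of a YAML config into a dict. Since patch_hf only needs a dict, a minimal sketch that builds one inline might look like this. The keys mirror the Mistral config quoted later in this thread; the value "inf-llm" for type is an assumption about the attention-type name (check the YAML files shipped in the repository), and split_patch_config is a hypothetical helper, not part of InfLLM.

```python
from typing import Any, Dict, Tuple

# Hypothetical inline equivalent of load_yaml_config()['model'].
# Keys mirror the Mistral config quoted in this thread; "inf-llm" as the
# attention type is an assumption, not confirmed by this thread.
MODEL_CONFIG: Dict[str, Any] = {
    "type": "inf-llm",
    "path": "mistralai/Mistral-7B-Instruct-v0.2",
    "block_size": 128,
    "n_init": 128,
    "n_local": 4096,
    "topk": 16,
    "repr_topk": 4,
    "max_cached_block": 32,
    "exc_block_size": 512,
    "fattn": False,
    "base": 1000000,
    "distance_scale": 1.0,
}


def split_patch_config(cfg: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Return (attn_type, remaining kwargs) without mutating cfg."""
    kwargs = dict(cfg)  # shallow copy so the caller's dict keeps its "type" key
    attn_type = kwargs.pop("type")
    return attn_type, kwargs


attn_type, patch_kwargs = split_patch_config(MODEL_CONFIG)
# With a real model loaded, the call would then be (not executed here):
# model = patch_hf(model, attn_type, **patch_kwargs)
```

Copying before popping keeps the original dict intact, so it can still be printed or reused after patching.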
thistleknot commented 4 months ago

"load_yaml_config"?

I tried

import yaml
from inf_llm.utils import patch_hf

# Simulated YAML configuration as a string
config_string = """
model:
  type: mistral  # Assuming 'mistral' is a valid type for your use case
  path: mistralai/Mistral-7B-Instruct-v0.2
  block_size: 128
  n_init: 128
  n_local: 4096
  topk: 16
  repr_topk: 4
  max_cached_block: 32
  exc_block_size: 512
  fattn: false
  base: 1000000
  distance_scale: 1.0
max_len: 2147483647
chunk_size: 8192
conv_type: mistral-inst

"""
# Parsing the YAML string into a dictionary
config = yaml.safe_load(config_string)

# Assuming 'model' is a pre-initialized model object
model = None  # Replace this with your actual model initialization logic

# Extracting type and the rest of the model configuration
model_type = config['model'].pop('type')

# Applying the configuration to the model
patched_model = patch_hf(model, model_type, **config['model'])

# Output for debugging
print("Model patched with the following configuration:")
print(config['model'])

and I get an error saying only certain models are supported (mistral, llama, etc.):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[21], line 35
     32 model_type = config['model'].pop('type')
     34 # Applying the configuration to the model
---> 35 patched_model = patch_hf(model, model_type, **config['model'])
     37 # Output for debugging
     38 print("Model patched with the following configuration:")

File /data/InfLLM/inf_llm/utils/patch.py:133, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
    125         return tuple(v for v in [hidden_states, pkv, all_hidden_states, all_self_attns] if v is not None)
    126     return BaseModelOutputWithPast(
    127         last_hidden_state=hidden_states,
    128         past_key_values=pkv,
    129         hidden_states=all_hidden_states,
    130         attentions=all_self_attns,
    131     )
--> 133 forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))
    135 if isinstance(model, LlamaForCausalLM):
    136     Attention = LlamaAttention

KeyError: 'mistral'
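The traceback points at ATTN_FORWRAD[attn_type] in patch.py: the second argument of patch_hf selects an attention implementation from a lookup table, while the model family is detected afterwards from the model's class (the isinstance(model, LlamaForCausalLM) branch). So type: mistral fails before the model is even inspected. Below is a toy sketch of that kind of dispatch with a more readable error. The registry and its keys are hypothetical stand-ins; check ATTN_FORWRAD in your copy of inf_llm/utils/patch.py for the real names.

```python
from typing import Callable, Dict


# Toy stand-ins for the attention factories; each just describes itself.
def _inf_llm_forward(**attn_kwargs) -> str:
    return f"inf-llm attention with {sorted(attn_kwargs)}"


def _origin_forward(**attn_kwargs) -> str:
    return "original attention"


# Hypothetical registry playing the role of ATTN_FORWRAD; keys are assumptions.
ATTN_REGISTRY: Dict[str, Callable[..., str]] = {
    "inf-llm": _inf_llm_forward,
    "origin": _origin_forward,
}


def select_attention(attn_type: str, **attn_kwargs) -> str:
    """Look up an attention type, failing with a message that explains the mix-up."""
    try:
        factory = ATTN_REGISTRY[attn_type]
    except KeyError:
        raise KeyError(
            f"{attn_type!r} is not an attention type; expected one of "
            f"{sorted(ATTN_REGISTRY)}. The model family (llama, mistral, "
            "qwen2) is detected from the model class, not from 'type'."
        ) from None
    return factory(**attn_kwargs)
```

The point of the sketch is the separation: type names an attention method, and the model family comes from the loaded model object.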
chris-aeviator commented 4 months ago

While I can patch the model, it won't work with the standard HF tools: Transformers assumes past_key_values is subscriptable, but here past_key_values is a ContextManager.


model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

generated_ids = patched_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]

in MistralForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, **kwargs)
   1206     max_cache_length = past_key_values.get_max_length()
   1207 else:
-> 1208     cache_length = past_length = past_key_values[0][0].shape[2]
   1209     max_cache_length = None
   1211 # Keep only the unprocessed tokens:
   1212 # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
   1213 # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
   1214 # input)

TypeError: 'ContextManager' object is not subscriptable
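The failing line is in MistralForCausalLM.prepare_inputs_for_generation, which generate() calls at every step and which indexes past_key_values[0][0] when the cache is not one of transformers' own cache objects. One workaround, sketched here as an assumption rather than a documented InfLLM API, is to skip generate() and run a hand-written greedy loop that only hands the cache back to the model and never indexes it. To keep the example runnable without a GPU, the model is abstracted as a hypothetical step function returning (next-token logits, cache).

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# step(token_ids, cache) -> (next-token logits, updated cache).
# Wiring this to a real patched model is an assumption and is not shown here.
StepFn = Callable[[List[int], Optional[Any]], Tuple[Sequence[float], Any]]


def greedy_decode(step: StepFn, prompt_ids: List[int], max_new_tokens: int,
                  eos_id: Optional[int] = None) -> List[int]:
    """Greedy decoding that treats the cache as an opaque object.

    The cache is only handed back to `step`; it is never indexed, so it does
    not need to be subscriptable (unlike what generate() expects).
    """
    if max_new_tokens <= 0:
        return []
    generated: List[int] = []
    # Prefill: run the whole prompt once, starting with no cache.
    logits, cache = step(list(prompt_ids), None)
    while True:
        # Greedy choice: index of the largest logit.
        next_id = max(range(len(logits)), key=lambda i: logits[i])
        generated.append(next_id)
        if len(generated) >= max_new_tokens or next_id == eos_id:
            break
        # Decode step: feed only the new token together with the previous cache.
        logits, cache = step([next_id], cache)
    return generated
```

With a real patched model, step would presumably call the model with input_ids holding the new tokens, past_key_values=cache and use_cache=True, then return outputs.logits[0, -1].tolist() and outputs.past_key_values. Whether the patched forward accepts its own ContextManager back this way should be checked against how chat.py drives the model.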
thistleknot commented 4 months ago

import yaml
from inf_llm.utils import patch_hf
from transformers import AutoModel

def load_yaml_config(file_path='path_to_your_config_file.yaml'):
    """ Load a YAML configuration file. """
    with open(file_path, 'r') as file:
        return yaml.safe_load(file)

# Load the configuration for infinite context
config_path = 'minicpm-inf-llm.yaml'
with open(config_path, 'r') as file:
    inf_llm_config = yaml.safe_load(file)
inf_llm_config

from inf_llm.utils import patch_hf
config = load_yaml_config(file_path=config_path)['model']
model = patch_hf(model, config['type'], **config)

produces

ValueError                                Traceback (most recent call last)
Cell In[26], line 3
      1 from inf_llm.utils import patch_hf
      2 config = load_yaml_config(file_path=config_path)['model']
----> 3 model = patch_hf(model, config['type'], **config)

File /home/user/mamba/InfLLM/inf_llm/utils/patch.py:150, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
    148         Model = model.model.__class__
    149     else:
--> 150         raise ValueError("Only supports llama, mistral and qwen2 models.")
    152     hf_rope = model.model.layers[0].self_attn.rotary_emb
    153     base = base if base is not None else hf_rope.base

ValueError: Only supports llama, mistral and qwen2 models.
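patch_hf picks the attention class to patch by checking the model's class (the isinstance(model, LlamaForCausalLM) branch in the first traceback), so it expects a causal-LM wrapper. If model came from AutoModel, which this snippet imports, it is likely the bare backbone without the language-modeling head, and loading it with AutoModelForCausalLM (plus trust_remote_code=True if the checkpoint ships its own modeling code) is a plausible fix, though this thread does not confirm it. A hypothetical pre-flight check that mirrors the error message, using class names only so it runs without transformers installed:

```python
from typing import Iterable

# Class names inferred from the error message "Only supports llama, mistral and
# qwen2 models."; your copy of patch.py may accept more (check its branches).
DEFAULT_SUPPORTED = ("LlamaForCausalLM", "MistralForCausalLM", "Qwen2ForCausalLM")


def check_patchable(model: object, supported: Iterable[str] = DEFAULT_SUPPORTED) -> None:
    """Raise early with a hint if `model` is not a supported *ForCausalLM class.

    Compares exact class names, so subclasses are not recognised (unlike isinstance).
    """
    name = type(model).__name__
    allowed = set(supported)
    if name in allowed:
        return
    hint = ""
    if not name.endswith("ForCausalLM"):
        # e.g. MistralModel from AutoModel: the backbone without the LM head
        hint = " Load it with AutoModelForCausalLM instead of AutoModel."
    raise TypeError(f"{name} is not one of {sorted(allowed)}.{hint}")
```

Calling check_patchable(model) just before patch_hf turns the generic ValueError into a message that names the actual class that was loaded.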