arcee-ai / mergekit

Tools for merging pretrained large language models.
GNU Lesser General Public License v3.0

RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' #175

Closed asphytheghoul closed 8 months ago

asphytheghoul commented 8 months ago

Hello, I am trying to use mergekit to merge several Llama-2 models that I have trained, using the dare_ties method. This is the config.yaml file I am using. Please help me in this regard.

models:
  - model: AsphyXIA/baarat-hindi-pretrained
    # No parameters necessary for base model
  - model: AsphyXIA/baarat-MTH
    parameters:
      density: 0.53
      weight: 0.4
  - model: AsphyXIA/baarat-hin-summarization
    parameters:
      density: 0.53
      weight: 0.3
  - model: AsphyXIA/baarat-hindi-qa
    parameters:
      density: 0.53
      weight: 0.5
merge_method: dare_ties
base_model: AsphyXIA/baarat-hindi-pretrained
parameters:
  int8_mask: true
dtype: bfloat16

I run the merge using the following command: mergekit-yaml config.yaml merge --copy-tokenizer --cuda --trust-remote-code
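For reference, the same merge can also be driven from Python. The snippet below is only a sketch based on mergekit's documented run_merge / MergeConfiguration API; the MergeOptions field names are assumed to mirror the CLI flags and may differ between versions.

import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge

# Load the config.yaml shown above into a MergeConfiguration object.
with open("config.yaml", "r", encoding="utf-8") as f:
    merge_config = MergeConfiguration.model_validate(yaml.safe_load(f))

# Run the merge into the "merge" directory; the options correspond to
# --cuda, --copy-tokenizer and --trust-remote-code on the CLI.
run_merge(
    merge_config,
    "merge",
    options=MergeOptions(
        cuda=True,
        copy_tokenizer=True,
        trust_remote_code=True,  # assumed to match the CLI flag; check your mergekit version
    ),
)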

I run inference using the following code:

from transformers import AutoTokenizer
import transformers
import torch
model = "merge/" #If you want to test your own model, replace this value with your model directory path, e.g., "merge/" if didn't renamed the directory where you merged the model
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": False},
)
# messages = [{"role": "user", "content": "Do you know how to cook pasta?"}]
# prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt= "Do you know how to cook pasta?"
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])

This is the full error:

RuntimeError                              Traceback (most recent call last)
Cell In[9], line 14
     11 # messages = [{"role": "user", "content": "Do you know how to cook pasta?"}]
     12 # prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     13 prompt= "Do you know how to cook pasta?"
---> 14 outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
     15 print(outputs[0]["generated_text"])

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/pipelines/text_generation.py:241, in TextGenerationPipeline.__call__(self, text_inputs, **kwargs)
    239         return super().__call__(chats, **kwargs)
    240 else:
--> 241     return super().__call__(text_inputs, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/pipelines/base.py:1196, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1188     return next(
   1189         iter(
   1190             self.get_iterator(
   (...)
   1193         )
   1194     )
   1195 else:
-> 1196     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/pipelines/base.py:1203, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1201 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1202     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1203     model_outputs = self.forward(model_inputs, **forward_params)
   1204     outputs = self.postprocess(model_outputs, **postprocess_params)
   1205     return outputs

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/pipelines/base.py:1102, in Pipeline.forward(self, model_inputs, **forward_params)
   1100     with inference_context():
   1101         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1102         model_outputs = self._forward(model_inputs, **forward_params)
   1103         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1104 else:

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/pipelines/text_generation.py:328, in TextGenerationPipeline._forward(self, model_inputs, **generate_kwargs)
    325         generate_kwargs["min_length"] += prefix_length
    327 # BS x SL
--> 328 generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
    329 out_b = generated_sequence.shape[0]
    330 if self.framework == "pt":

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/generation/utils.py:1592, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1584     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1585         input_ids=input_ids,
   1586         expand_size=generation_config.num_return_sequences,
   1587         is_encoder_decoder=self.config.is_encoder_decoder,
   1588         **model_kwargs,
   1589     )
   1591     # 13. run sample
-> 1592     return self.sample(
   1593         input_ids,
   1594         logits_processor=prepared_logits_processor,
   1595         logits_warper=logits_warper,
   1596         stopping_criteria=prepared_stopping_criteria,
   1597         pad_token_id=generation_config.pad_token_id,
   1598         eos_token_id=generation_config.eos_token_id,
   1599         output_scores=generation_config.output_scores,
   1600         output_logits=generation_config.output_logits,
   1601         return_dict_in_generate=generation_config.return_dict_in_generate,
   1602         synced_gpus=synced_gpus,
   1603         streamer=streamer,
   1604         **model_kwargs,
   1605     )
   1607 elif generation_mode == GenerationMode.BEAM_SEARCH:
   1608     # 11. prepare beam search scorer
   1609     beam_scorer = BeamSearchScorer(
   1610         batch_size=batch_size,
   1611         num_beams=generation_config.num_beams,
   (...)
   1616         max_length=generation_config.max_length,
   1617     )

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/generation/utils.py:2696, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2693 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2695 # forward pass to get next token
-> 2696 outputs = self(
   2697     **model_inputs,
   2698     return_dict=True,
   2699     output_attentions=output_attentions,
   2700     output_hidden_states=output_hidden_states,
   2701 )
   2703 if synced_gpus and this_peer_finished:
   2704     continue  # don't waste resources running the code we don't need

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1168, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1165 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1167 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1168 outputs = self.model(
   1169     input_ids=input_ids,
   1170     attention_mask=attention_mask,
   1171     position_ids=position_ids,
   1172     past_key_values=past_key_values,
   1173     inputs_embeds=inputs_embeds,
   1174     use_cache=use_cache,
   1175     output_attentions=output_attentions,
   1176     output_hidden_states=output_hidden_states,
   1177     return_dict=return_dict,
   1178     cache_position=cache_position,
   1179 )
   1181 hidden_states = outputs[0]
   1182 if self.config.pretraining_tp > 1:

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1008, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    997     layer_outputs = self._gradient_checkpointing_func(
    998         decoder_layer.__call__,
    999         hidden_states,
   (...)
   1005         cache_position,
   1006     )
   1007 else:
-> 1008     layer_outputs = decoder_layer(
   1009         hidden_states,
   1010         attention_mask=causal_mask,
   1011         position_ids=position_ids,
   1012         past_key_value=past_key_values,
   1013         output_attentions=output_attentions,
   1014         use_cache=use_cache,
   1015         cache_position=cache_position,
   1016     )
   1018 hidden_states = layer_outputs[0]
   1020 if use_cache:

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:734, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    731 hidden_states = self.input_layernorm(hidden_states)
    733 # Self Attention
--> 734 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    735     hidden_states=hidden_states,
    736     attention_mask=attention_mask,
    737     position_ids=position_ids,
    738     past_key_value=past_key_value,
    739     output_attentions=output_attentions,
    740     use_cache=use_cache,
    741     cache_position=cache_position,
    742     **kwargs,
    743 )
    744 hidden_states = residual + hidden_states
    746 # Fully Connected

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:347, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    344     value_states = torch.cat(value_states, dim=-1)
    346 else:
--> 347     query_states = self.q_proj(hidden_states)
    348     key_states = self.k_proj(hidden_states)
    349     value_states = self.v_proj(hidden_states)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/baarat/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
cg123 commented 8 months ago

This error is occurring because your inference code is trying to run the model in fp16 on the CPU, which transformers/PyTorch do not support (the CPU addmm kernel behind F.linear has no Half implementation). If you have a GPU on the machine you're running this on, I'd recommend using it - you can do that by changing model_kwargs like so:

    model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": False, "device_map": "auto"},

If you don't have a GPU then you might be able to run the model in either fp32 or bf16 precision. Try like so:

    model_kwargs={"torch_dtype": torch.float32, "load_in_4bit": False}, # or torch.bfloat16