huggingface / local-gemma

Gemma 2 optimized for your local machine.
Apache License 2.0
344 stars · 27 forks

Feature req: Support for Gemma-2-2b and Gemma-2-2b-it #35

Closed · notnotrishi closed this 3 months ago

notnotrishi commented 3 months ago

These new models don't seem to work currently. I'm on a Mac and get the following error: ValueError: Can't infer missing attention mask on mps device. Please provide an attention_mask or use a different device.

sanchit-gandhi commented 3 months ago

Hey @notnotrishi! Could you confirm that you're running the latest version of Local Gemma from main? E.g. with:

pip install --upgrade git+https://github.com/huggingface/local-gemma

I've tested the lib on an M1 Mac and am able to run the code without errors! E.g.

$ local-gemma --model "google/gemma-2-2b-it"  --preset "auto" --seed 0 --max_new_tokens 1024

Loading model with the following characteristics:
- Model name: gg-tt/gemma-2.0-2b-it-transformers
- Assistant model name: None
- Device: mps
- Default data type: torch.bfloat16
- Optimization preset: exact
- Generation arguments: {'do_sample': True, 'temperature': 0.7}
- Base prompt: None

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 14.62it/s]

You can now interact with the model through a conversation. A few tips:
- Initialize the program with '--silent' to hide all non-model messages
- Input '!exit' to leave the program
- Input '!new session' to reset the conversation
- Input '!help' to print this message again

>>> Hey!
Hey there! 👋  How can I help you today? 😊 
<end_of_turn>
>>> 
sanchit-gandhi commented 3 months ago

(By the way, performance is pretty decent for the 2b model on mps. I'm getting approx. 6 tok/s.)
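For anyone wanting to reproduce that figure, here's a rough way to time generation (a minimal sketch; the prompt and generation length are illustrative, not the exact benchmark above):

import time

from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

# tokenizer(...) returns both input_ids and attention_mask, so generate
# gets an explicit mask and runs cleanly on mps
inputs = tokenizer("Write a short poem about the sea.", return_tensors="pt").to(model.device)

start = time.time()
generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
elapsed = time.time() - start

new_tokens = generated_ids.shape[-1] - inputs["input_ids"].shape[-1]
print(f"{new_tokens / elapsed:.1f} tok/s")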

notnotrishi commented 3 months ago

thanks @sanchit-gandhi! it works from the CLI, but I'm still getting that error with the Python/Transformers example code described in the repo README
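For context, the pre-fix README example presumably looked something like this (a reconstruction, not a quote from the repo, but it reproduces the error):

from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

messages = [{"role": "user", "content": "Hey!"}]

# Without return_dict=True, apply_chat_template returns only the token IDs,
# so generate() receives no attention_mask and mps raises the ValueError above
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
generated_ids = model.generate(input_ids.to(model.device), max_new_tokens=128)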

sanchit-gandhi commented 3 months ago

I've updated the README example to prepare an attention mask:

from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

# return_dict=True makes apply_chat_template return the attention_mask
# alongside the input_ids
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)

generated_ids = model.generate(**model_inputs.to(model.device), max_new_tokens=1024, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)
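
The key change is return_dict=True: apply_chat_template then returns the attention_mask alongside the input_ids, so generate receives an explicit mask and no longer needs to infer one on mps.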

notnotrishi commented 3 months ago

thanks @sanchit-gandhi for the prompt responses and the fix! 🙏