Run Gemma-2 locally in Python, fast!
This repository provides an easy and fast way to run Gemma-2 locally, directly from your CLI or via a Python library. It is built on top of the 🤗 Transformers and bitsandbytes libraries.

It can be configured to give fully equivalent results to the original implementation, or to reduce memory requirements down to just the largest layer in the model!
> [!IMPORTANT]
> There is a new "speed" preset for running `local-gemma` on CUDA ⚡️ It makes use of `torch.compile` for up to 6x faster generation. Set `--preset="speed"` when using the CLI, or pass `preset="speed"` to `from_pretrained` when using the Python API.
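For example, a minimal sketch of the Python route (using the `LocalGemma2ForCausalLM` class and the `google/gemma-2-2b-it` checkpoint that appear later in this README) would be:

```python
from local_gemma import LocalGemma2ForCausalLM

# "speed" is CUDA-only and uses torch.compile for faster generation
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="speed")
```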
There are two installation flavors of `local-gemma`, which you can select depending on your use case:

- `pipx` - Ideal for CLI
- `pip` - Ideal for Python (CLI + API)

You can chat with Gemma-2 through an interactive session by calling:
```sh
local-gemma
```
> [!TIP]
> Local Gemma will check for a Hugging Face "read" token to download the model. You can follow this guide to create a token, and pass it when prompted to log in. If you're new to Hugging Face and have never used a Gemma model, you'll also need to accept the Gemma terms of use at the top of the model's page on the Hugging Face Hub.
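>
> If you prefer to authenticate ahead of time rather than waiting for the prompt, a minimal sketch using the `huggingface_hub` library (installed alongside 🤗 Transformers) is:
>
> ```python
> from huggingface_hub import login
>
> # Opens an interactive prompt for your Hugging Face "read" token;
> # alternatively, pass it directly with login(token="hf_...")
> login()
> ```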
Alternatively, you can request a single output by passing a prompt, such as:
local-gemma "What is the capital of France?"
By default, this loads the Gemma-2 9b it model. To load the 2b it or 27b it models, you can set the `--model` argument accordingly:
```sh
local-gemma --model 2b
```
Local Gemma-2 will automatically find the most performant preset for your hardware, trading off speed and memory. For more control over generation speed and memory usage, set the `--preset` argument to one of four available options: `exact`, `speed` (CUDA-only), `memory` or `memory_extreme` (see the preset tables below).
You can also control the style of the generated text through the `--mode` flag, one of "chat", "factual" or "creative":
```sh
local-gemma --model 9b --preset memory --mode factual
```
Finally, you can also pipe in the output of other commands, which will be appended to the prompt after a `\n` separator:
```sh
ls -la | local-gemma "Describe my files"
```
To see all available decoding options, call `local-gemma -h`.
> [!NOTE]
> The `pipx` installation method creates its own Python environment, so you will need to use the `pip` installation method to use this library in a Python script.
Local Gemma-2 can be run locally through a Python interpreter using the familiar Transformers API. To enable a preset, import the model class from `local_gemma` and pass the `preset` argument to `from_pretrained`. For example, the following code snippet loads the Gemma-2 9b model with the "memory" preset:
```python
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

# "memory" preset: int4 weights (via bitsandbytes) to minimise memory usage
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b", preset="memory")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

model_inputs = tokenizer("The cat sat on the mat", return_attention_mask=True, return_tensors="pt")
generated_ids = model.generate(**model_inputs.to(model.device))
decoded_text = tokenizer.batch_decode(generated_ids)
```
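Because `generate` returns the prompt tokens followed by the new tokens, the decoded text above includes the prompt. If you only want the continuation, one way (continuing the snippet above) is to slice off the prompt before decoding:

```python
# Drop the prompt tokens so only the model's continuation is decoded
prompt_length = model_inputs["input_ids"].shape[1]
new_text = tokenizer.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)[0]
print(new_text)
```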
When using an instruction-tuned model (suffixed by `-it`) for conversational use, prepare the inputs with a chat template. The following example loads the Gemma-2 2b it model using the "auto" preset, which automatically determines the best preset for the device:
```python
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

# Format the conversation with the Gemma chat template and append the
# generation prompt so the model replies to the last user turn
model_inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
)

generated_ids = model.generate(**model_inputs.to(model.device), max_new_tokens=1024, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)
```
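To keep the conversation going, one simple pattern (continuing the example above; the follow-up question is just an illustration) is to decode only the new assistant turn and append it to `messages` before generating again:

```python
# Decode only the tokens generated after the prompt, i.e. the new assistant turn
prompt_length = model_inputs["input_ids"].shape[1]
assistant_reply = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)

# Extend the conversation history, then re-apply the chat template and
# call model.generate again for the next reply
messages.append({"role": "assistant", "content": assistant_reply})
messages.append({"role": "user", "content": "Which of those is easiest for a beginner?"})
```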
Local Gemma-2 provides four presets that trade off accuracy, speed and memory. The following results highlight this trade-off using Gemma-2 9b with batch size 1 on an 80GB A100 GPU:
| Mode | Performance* | Inference Speed (tok/s) | Memory (GB) |
|---|---|---|---|
| exact | 73.0 | 17.2 | 18.3 |
| speed (CUDA-only) | 73.0 | 62.0 | 19.0 |
| memory | 72.1 | 13.8 | 7.3 |
| memory_extreme | 72.1 | 13.8 | 7.3 |
While an 80GB A100 can fit the full model on the device, only 3.7GB of memory is required with the `memory_extreme` preset. See the Preset Details section below for details.

*Zero-shot results averaged over Wino, ARC Easy, ARC Challenge, PIQA, HellaSwag, MMLU and OpenBook QA.
| Mode | 2b Min Memory (GB) | 9b Min Memory (GB) | 27b Min Memory (GB) | Weights dtype | CPU Offload |
|---|---|---|---|---|---|
| exact | 5.3 | 18.3 | 54.6 | bf16 | no |
| speed (CUDA-only) | 5.4 | 19.0 | 55.8 | bf16 | no |
| memory | 3.7 | 7.3 | 17.0 | int4 | no |
| memory_extreme | 1.8 | 3.7 | 4.7 | int4 | yes |
`memory_extreme` implements CPU offloading through 🤗 Accelerate, reducing memory requirements down to the largest layer in the model (which in this case is the LM head).
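As a sketch, assuming the Python API accepts the same preset names as the CLI (via the `preset` argument shown above), loading the 27b instruction-tuned checkpoint with this preset would look like:

```python
from local_gemma import LocalGemma2ForCausalLM

# int4 weights + CPU offload: only the largest layer has to fit on the device
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-27b-it", preset="memory_extreme")
```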
Local Gemma-2 is a convenient wrapper around several open-source projects, most notably 🤗 Transformers, bitsandbytes and 🤗 Accelerate, which we thank explicitly.
And last but not least, thank you to Google for the pre-trained Gemma-2 checkpoints, all of which you can find on the Hugging Face Hub.