google-deepmind / gemma

Open weights LLM from Google DeepMind.
http://ai.google.dev/gemma
Apache License 2.0

Colabs don't seem to work #10

Closed hamzamerzic closed 1 month ago

hamzamerzic commented 8 months ago

I cannot get the Colabs to run on https://colab.research.google.com.

I had to replace

!pip install https://github.com/deepmind/gemma

with

!pip install "git+https://github.com/google-deepmind/gemma.git"

as the former repository does not exist.

I am still unable to get the library versions to match so that the code runs. Also, since Google provides a free TPU tier for Colab, it would be great if the code could be adapted (or some notes included) so it runs on TPU as well as GPU.

After fixing the gemma install and updating the JAX installation to:

!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

the code ends up failing with the following stack trace:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
[<ipython-input-6-cb05cf1a7a98>](https://localhost:8080/#) in <cell line: 2>()
      1 import re
----> 2 from gemma import params as params_lib
      3 from gemma import sampler as sampler_lib
      4 from gemma import transformer as transformer_lib
      5 

3 frames
[/usr/local/lib/python3.10/dist-packages/gemma/params.py](https://localhost:8080/#) in <module>
     20 import jax
     21 import jax.numpy as jnp
---> 22 import orbax.checkpoint
     23 
     24 Params = Mapping[str, Any]

[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py](https://localhost:8080/#) in <module>
     17 import functools
     18 
---> 19 from orbax.checkpoint import checkpoint_utils
     20 from orbax.checkpoint import lazy_utils
     21 from orbax.checkpoint import test_utils

[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py](https://localhost:8080/#) in <module>
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     27 

[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py](https://localhost:8080/#) in <module>
     22 from etils import epath
     23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
     25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
     26 import jax.numpy as jnp

ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------
hamzamerzic commented 8 months ago

Adding:

!pip install -U orbax
!pip install -U chex

in addition to the install changes suggested in the issue seems to make the gsm8k_eval.ipynb imports work. I guess now we just need to wait for the parameter and vocab checkpoints to become available.
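For reference, you can check which of the notebook's dependencies are still missing before re-running the imports. This is a minimal sketch (the package list is assumed from the fixes discussed above; `missing_packages` is a hypothetical helper, not part of the Gemma API):

```python
import importlib.util

def missing_packages(names):
    """Return the subset of `names` that cannot be imported."""
    return [n for n in names if importlib.util.find_spec(n) is None]

# Packages gsm8k_eval.ipynb needs after the install fixes above.
required = ["jax", "orbax", "chex", "gemma"]
print(missing_packages(required))
```

An empty list means all four top-level packages resolve; anything printed still needs a `!pip install`.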

selamw1 commented 7 months ago

Hi @hamzamerzic,

I understand you're facing issues running Colab notebooks that require the Gemma model and potentially using TPUs for acceleration. Here's a breakdown of the steps to get you started:

  1. Install the latest Google Cloud TPU libraries:

     `!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`

     This command installs the latest libraries needed to interact with Google Cloud TPUs through JAX.

  2. Verify TPU availability:

     ```python
     import jax
     import jax.tools.colab_tpu

     jax.tools.colab_tpu.setup_tpu()
     print(jax.devices())
     ```

     Running this code snippet will attempt to set up the TPU runtime and then print the available TPU devices in your Colab environment.

  3. Install the Gemma model:

     `!pip install git+https://github.com/google-deepmind/gemma.git`
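The device check in step 2 can be made a bit more defensive, since the runtime may expose GPU or CPU instead of TPU. A hedged sketch, assuming only that `jax.devices()` returns objects with a `platform` attribute; `pick_backend` is a hypothetical helper, not part of JAX or Gemma:

```python
def pick_backend(platforms):
    """Given the set of device platforms JAX reports
    (e.g. {"tpu"}, {"gpu"}, {"cpu"}), return the best
    available backend name, preferring TPU, then GPU."""
    for preferred in ("tpu", "gpu", "cpu"):
        if preferred in platforms:
            return preferred
    return "cpu"

# Usage in Colab (requires jax to be installed):
#   import jax
#   print(pick_backend({d.platform for d in jax.devices()}))
```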

Additional Resource:

- This Colab [gist](https://colab.sandbox.google.com/gist/selamw1/04dc59eea35366e8b417d0ed501f53c9/gemma_inference_on_tpus.ipynb#scrollTo=F2gYnt9Uqj8M) demonstrating Gemma inference on TPUs might be helpful.
gustheman commented 4 months ago

@selamw1 I liked your Colab! That would be a nice recipe for the Gemma cookbook: goo.gle/gemma-cookbook

For future users, we have some more tutorials using JAX + Gemma here: ai.google.dev/gemma

selamw1 commented 4 months ago

Great, thanks @gustheman!

This PR added two new tutorials:

Gopi-Uppari commented 2 months ago

Hi @hamzamerzic,

Could you please confirm whether this issue is resolved for you by the above comments? If so, please feel free to close it.

Thank you.

tilakrayal commented 1 month ago

Closing this issue due to lack of recent activity. Please feel free to reopen if this is still a valid request. Thank you!