hamzamerzic closed this issue 1 month ago.
Adding the following, in addition to the import changes suggested in the issue, seems to make the imports in gsm8k_eval.ipynb work:
`!pip install -U orbax`
`!pip install -U chex`
I guess now we just need to wait for the parameter and vocab checkpoints to become available.
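A quick way to confirm that the notebook's dependencies import cleanly before the checkpoints arrive is a small check cell. This is a minimal sketch: `check_imports` is a hypothetical helper, and the module list (in particular `orbax.checkpoint` as the Orbax module and `gemma` as the package installed from GitHub) is an assumption that should be adjusted to match what gsm8k_eval.ipynb actually imports.

```python
import importlib


def check_imports(module_names):
    """Try to import each module and return {name: None on success, else the error text}."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # ImportError, or any error raised while the module initializes
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


# Assumed dependency list for the notebook; edit to match its import cell.
failures = {
    name: error
    for name, error in check_imports(["jax", "chex", "orbax.checkpoint", "gemma"]).items()
    if error
}
print("all imports OK" if not failures else failures)
```

Catching `Exception` rather than only `ImportError` also surfaces version mismatches that only fail once a module starts initializing.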
Hi @hamzamerzic,
I understand you're facing issues running Colab notebooks that require the Gemma model, potentially on TPUs for acceleration. Here's a breakdown of the steps to get you started:
1. Install the Latest Google Cloud TPU Libraries:
`!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`
This command installs the latest libraries needed to interact with Google Cloud TPUs through JAX.
2. Verify TPU Availability:
```python
import jax
import jax.tools.colab_tpu

# Configure JAX to use the Colab TPU runtime, then list the devices JAX can see.
jax.tools.colab_tpu.setup_tpu()
print(jax.devices())
```
Running this code snippet will attempt to set up the TPU runtime and then print the available TPU devices in your Colab environment.
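Depending on the JAX version and the Colab runtime, the `setup_tpu()` call may be unnecessary or unavailable, so a check that only reads what JAX already sees can be more robust. This is a minimal sketch under that assumption; `describe_accelerators` is a hypothetical helper, and it falls back gracefully if the install step failed.

```python
def describe_accelerators():
    """Return (default_backend, device_platforms) for the current JAX runtime.

    Returns ("jax-not-installed", []) so the cell still runs if the install failed.
    """
    try:
        import jax
    except ImportError:
        return "jax-not-installed", []
    # default_backend() is typically "cpu", "gpu", or "tpu"; each device reports its platform.
    return jax.default_backend(), [d.platform for d in jax.devices()]


backend, platforms = describe_accelerators()
print(f"default backend: {backend}; devices: {platforms}")
if backend != "tpu":
    print("No TPU visible to JAX; check Runtime > Change runtime type in Colab.")
```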
3. Install the Gemma Model:
`!pip install git+https://github.com/google-deepmind/gemma.git`
Additional Resource:
- This Colab notebook [gist](https://colab.sandbox.google.com/gist/selamw1/04dc59eea35366e8b417d0ed501f53c9/gemma_inference_on_tpus.ipynb#scrollTo=F2gYnt9Uqj8M) demonstrating Gemma inference on TPUs might be helpful.
@selamw1 I liked your Colab! That would be a nice recipe for the Gemma cookbook: goo.gle/gemma-cookbook
For future users, we have some more tutorials using JAX + Gemma here: ai.google.dev/gemma
Great, thanks @gustheman!
This PR added two new tutorials:
- gemma_inference_on_tpu: Demonstrates basic inference with Gemma on TPUs.
- gemma-data-parallel-inference-in-jax-tpu: Showcases data-parallel inference for faster processing on TPUs using JAX.
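The core of data-parallel inference is splitting one batch of prompts evenly across the available devices, running the same model on each slice, and stitching the results back together. The sketch below covers only the batch-preparation step, assuming the batch size may not divide evenly by the device count; it uses NumPy so it runs anywhere. `shard_batch` and `unshard` are hypothetical helpers, not part of the Gemma library or either tutorial. In JAX, the sharded array would then typically be passed to something like `jax.pmap(model_fn)`, where `model_fn` is a placeholder for a per-device forward pass; `pmap` maps over the leading (device) axis.

```python
import numpy as np


def shard_batch(batch: np.ndarray, num_devices: int):
    """Pad the leading (batch) axis to a multiple of num_devices, then reshape
    to [num_devices, per_device_batch, ...]. Returns (sharded, original_size)."""
    n = batch.shape[0]
    per_device = -(-n // num_devices)  # ceiling division
    pad = per_device * num_devices - n
    if pad:
        # Zero-pad extra rows at the end of the batch axis only.
        pad_width = [(0, pad)] + [(0, 0)] * (batch.ndim - 1)
        batch = np.pad(batch, pad_width)
    return batch.reshape(num_devices, per_device, *batch.shape[1:]), n


def unshard(outputs: np.ndarray, original_size: int) -> np.ndarray:
    """Undo shard_batch: merge the device axis back and drop the padded rows."""
    merged = outputs.reshape(-1, *outputs.shape[2:])
    return merged[:original_size]


# Toy example: 10 prompts of 4 token ids each, spread over 8 devices.
tokens = np.arange(10 * 4).reshape(10, 4)
sharded, n = shard_batch(tokens, num_devices=8)
print(sharded.shape)  # (8, 2, 4)
assert (unshard(sharded, n) == tokens).all()
```

Padding keeps every device's slice the same shape, which is what `pmap` expects; the outputs for padded rows are meaningless and are discarded by `unshard`.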
Hi @hamzamerzic,
Could you please confirm whether the comments above resolve this issue for you? If so, please feel free to close it.
Thank you.
Closing this issue due to lack of recent activity. Please feel free to reopen it if this is still a valid request. Thank you!
I cannot get the Colabs to run on https://colab.research.google.com.
I had to replace
with
as the former repository does not exist.
I am still unable to get the package versions to match so that the code runs. Also, Google provides a free TPU tier for Colab, so it would be great if the code could be adapted (or some notes included) to run on TPU as well as GPU.
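When versions do not line up, listing exactly what is installed makes the mismatch easier to reproduce. This is a minimal sketch using only the standard library; `installed_versions` is a hypothetical helper, and the distribution names (especially `gemma` as the name pip registers for the GitHub install) are assumptions to adjust as needed.

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each pip distribution name to its installed version, or None if absent."""
    versions = {}
    for dist in distributions:
        try:
            versions[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            versions[dist] = None
    return versions


# Assumed distribution names; edit to match the notebook's requirements.
for name, version in installed_versions(
    ["jax", "jaxlib", "chex", "orbax-checkpoint", "gemma"]
).items():
    print(f"{name:>18}: {version or 'not installed'}")
```

Pasting this output alongside a stack trace gives maintainers the exact environment to compare against.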
After fixing the gemma install and updating the JAX import as:
the code ends up failing with the following stack trace: