XiaoyuZeng closed this issue 4 months ago.
Hi @XiaoyuZeng
Thanks for bringing this to our attention. There are a few things in our tutorials that are now out of date. We will update those asap.
For your convenience, here are a few simple steps to get HSSM working in Colab with a GPU backend:
1. Install HSSM:
!pip install hssm
or
!pip install git+https://github.com/lnccbrown/HSSM.git
to install the dev version.
2. Set floatX to float32 via:
import hssm
hssm.set_floatX("float32")
That's all you need to do to get JAX working with a GPU backend. JAX is installed on Colab by default with GPU support enabled, so there is no need to mess with that. No additional steps are needed to run the computation on the GPU: you simply need to specify a JAX-based sampler (set sampler="nuts_numpyro" when calling model.sample()).
Is your feature request related to a problem? Please describe.
Somehow I cannot access the Colab notebook linked from the main tutorial (https://lnccbrown.github.io/HSSM/tutorials/main_tutorial/, at "If you would like to run this tutorial on Google colab, please click this link.").
Describe the solution you'd like
I struggled to get HSSM running in Colab through trial and error, so I think a comprehensive tutorial designed for Colab that enables GPU acceleration would be highly appreciated.
Things I learned through this process that could be included in the Colab tutorial notebook: