lnccbrown / HSSM

Development of HSSM package

Request for a comprehensive Colab tutorial that enables GPU acceleration #396

Closed: XiaoyuZeng closed this issue 4 months ago

XiaoyuZeng commented 5 months ago

Is your feature request related to a problem? Please describe.
Somehow I cannot access the Colab notebook linked from the main tutorial (https://lnccbrown.github.io/HSSM/tutorials/main_tutorial/), where it says "If you would like to run this tutorial on Google Colab, please click this link."

Describe the solution you'd like
I struggled to get HSSM running in Colab through trial and error, so I think a comprehensive tutorial designed for Colab, with GPU acceleration enabled, would be highly appreciated.

Things I learned through this process that could be included in the Colab tutorial notebook:


digicosmos86 commented 5 months ago

Hi @XiaoyuZeng

Thanks for bringing this to our attention. There are a few things in our tutorials that are now out of date. We will update those asap.

For your convenience, here are a few simple steps to get HSSM working in Colab with a GPU backend:

  1. Install HSSM with `!pip install hssm`, or with `!pip install git+https://github.com/lnccbrown/HSSM.git` to install the development version.
  2. Import hssm and set the float type to float32:

     ```python
     import hssm
     hssm.set_floatX("float32")
     ```

That's all you need to get JAX working with the GPU backend. JAX comes preinstalled on Colab with GPU support enabled (when a GPU runtime is selected), so there is no need to touch it. No additional steps are needed to run the computation on the GPU: simply specify a JAX-based sampler by setting `sampler="nuts_numpyro"` when calling `model.sample()`.
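A minimal end-to-end sketch of these steps, assuming a recent HSSM release in which `hssm.simulate_data`, `hssm.HSSM`, and `model.sample(sampler="nuts_numpyro")` behave as in the project's main tutorial (check the docs for your installed version). The helper names `jax_backend` and `fit_ddm_with_numpyro` and the DDM parameter values are hypothetical, chosen only for illustration. The first helper reports which backend JAX picked; the second fits a simple DDM with the JAX-based sampler.

```python
import importlib.util


def jax_backend() -> str:
    """Return JAX's default backend name (e.g. "gpu" or "cpu").

    Returns "jax not installed" instead of raising if JAX is unavailable.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax

    return jax.default_backend()


def fit_ddm_with_numpyro(draws: int = 500, tune: int = 500, chains: int = 2):
    """Simulate DDM data and fit it with HSSM's JAX-based NUTS sampler.

    Hypothetical helper; hssm is imported lazily so this cell loads even
    where HSSM is not yet installed.
    """
    import hssm

    # Step 2 above: switch to float32, as recommended for JAX on GPU.
    hssm.set_floatX("float32")

    # Illustrative (hypothetical) ground-truth parameters [v, a, z, t].
    dataset = hssm.simulate_data(
        model="ddm",
        theta=[0.5, 1.5, 0.5, 0.5],
        size=1000,
    )

    model = hssm.HSSM(data=dataset, model="ddm")

    # sampler="nuts_numpyro" hands sampling to NumPyro via JAX, so it runs
    # on the GPU whenever jax_backend() reports "gpu".
    return model.sample(
        sampler="nuts_numpyro",
        draws=draws,
        tune=tune,
        chains=chains,
    )
```

In a Colab cell, run `print(jax_backend())` before sampling; if it prints `cpu`, switch the runtime type to GPU (Runtime > Change runtime type) and restart the session. Then `idata = fit_ddm_with_numpyro()` returns the sampling results. Importing `hssm` inside the function is just a convenience so the GPU check can run on its own.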