google-deepmind / PGMax

Loopy belief propagation for factor graphs on discrete variables in JAX
Apache License 2.0

some (very) small suggestions on the examples #3

Closed murphyk closed 3 months ago

murphyk commented 1 year ago
antoine-dedieu commented 3 months ago

Thank you @murphyk for these detailed comments!

  1. We have renamed the gmrf.py notebook to grid_mrf.py to avoid confusion. Instead of training a GridMRF model from scratch, which was very slow, we now perturb the weights of the pretrained model and fine-tune them on only 500 samples, starting from the perturbed weights. On a Colab GPU instance with 16 GB of RAM, fine-tuning should now take around 5 minutes and require only 11 GB of RAM.

  2. We are now properly using jax.random with jax.vmap in all our notebooks.

  3. We have removed these obscure lines.

  4. We have modified the Ising example and are now introducing it in our README as a tutorial, so that users can easily get started with PGMax.
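
For readers following items 1 and 2, here is a minimal sketch of the general pattern: split one PRNG key into independent keys, then use jax.vmap to draw a batch of perturbed copies of some weights. It is an illustration only, not code from the PGMax notebooks; the `perturb` helper, the `scale` value, and the zero-valued `pretrained` array are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


def perturb(key, weights, scale=0.1):
    """Add Gaussian noise to `weights` using an explicit PRNG key.

    `perturb` and `scale` are hypothetical names chosen for this sketch.
    """
    noise = scale * jax.random.normal(key, weights.shape)
    return weights + noise


# Hypothetical stand-in for pretrained weights (not PGMax's actual parameters).
pretrained = jnp.zeros((3, 3))

# Split one root key into independent keys, one per perturbed copy.
root_key = jax.random.PRNGKey(0)
keys = jax.random.split(root_key, 4)

# Map over the keys only; the same weights are shared across the batch.
perturbed = jax.vmap(perturb, in_axes=(0, None))(keys, pretrained)

print(perturbed.shape)
```

Passing a distinct key to each mapped call is what keeps the draws independent; reusing a single key inside the mapped function would produce identical noise for every copy.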