Closed by murphyk 3 months ago
Thank you @murphyk for these detailed comments!
We have renamed the `gmrf.py` notebook to `grid_mrf.py` to avoid confusion. Instead of training a GridMRF model from scratch, which was really slow, we now simply perturb the weights of the pretrained model and finetune them for only 500 samples, starting from these perturbed weights. On a Colab GPU instance with 16 GB of RAM, finetuning should now take around 5 minutes and require only about 11 GB of RAM.
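For readers who want a picture of the "perturb, then finetune" idea, here is a minimal sketch of the perturbation step only. It is not the notebook's code: the function `perturb_weights`, the `noise_scale` value, and the toy `pretrained` dict are all hypothetical stand-ins, and it assumes the weights are stored as a pytree of arrays.

```python
import jax
import jax.numpy as jnp


def perturb_weights(key, weights, noise_scale=0.1):
    """Add independent Gaussian noise to every leaf of a pytree of weights."""
    leaves, treedef = jax.tree_util.tree_flatten(weights)
    # One independent PRNG key per leaf, so no two arrays share noise.
    keys = jax.random.split(key, len(leaves))
    noisy = [
        w + noise_scale * jax.random.normal(k, w.shape, w.dtype)
        for w, k in zip(leaves, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, noisy)


# Hypothetical stand-in for pretrained weights (names and shapes are made up).
pretrained = {"log_potentials": jnp.zeros((4, 4)), "bias": jnp.ones((4,))}
perturbed = perturb_weights(jax.random.PRNGKey(0), pretrained)
```

Finetuning would then start from `perturbed` instead of a random initialization, which is why far fewer samples are needed.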
We are now properly using `jax.random` with `jax.vmap` in all our notebooks.
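As an illustration of the general pattern (not a copy of the notebook code), one key is split into a batch of keys and the per-sample function is mapped over them with `jax.vmap`. The function `sample_noisy_image` and its parameters are hypothetical, chosen only to keep the sketch self-contained.

```python
import jax
import jax.numpy as jnp


def sample_noisy_image(key, shape=(8, 8), flip_prob=0.1):
    """Draw one random binary image and flip each pixel with probability flip_prob."""
    # Split so the image and the flip mask use independent randomness.
    key_img, key_flip = jax.random.split(key)
    image = jax.random.bernoulli(key_img, 0.5, shape)
    flips = jax.random.bernoulli(key_flip, flip_prob, shape)
    return jnp.logical_xor(image, flips)


# One key per sample; vmap maps over the leading axis of `keys`.
keys = jax.random.split(jax.random.PRNGKey(42), 16)
batch = jax.vmap(sample_noisy_image)(keys)  # boolean array of shape (16, 8, 8)
```

Each sample receives its own key, so the batch is reproducible from the single seed and no Python loop is needed.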
We have removed these obscure lines.
We have modified the Ising example and are now introducing it in our README as a tutorial, so that users can easily get started with PGMax.
`from jax.example_libraries import optimizers` with optax. Also the term "GMRF" usually refers to Gaussian MRF, not grid :) Finally the learning code is super slow, even on an A100... (public colab)
`batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0))` is very obscure. What is the dict with a key called 'variables'? And in what sense are these evidence / potential 'updates', as opposed to just static things?
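For context on that `in_axes` pattern: `jax.vmap` accepts a pytree for `in_axes`, so a dict there states, per key, which axis of the corresponding argument entry is the batch axis. The toy sketch below shows the mechanism only. It assumes a plain string key `"variables"` and a made-up `loss`; the key used in the notebook was evidently a Python object named `variables`, and its real loss is not reproduced here.

```python
import jax
import jax.numpy as jnp


def loss(weights, evidence_updates):
    """Toy per-sample loss; stands in for the real one (an assumption)."""
    score = evidence_updates["variables"] @ weights  # (n,) @ (n,) -> scalar
    return (score - 1.0) ** 2


# None: `weights` is shared by all samples (not batched).
# {"variables": 0}: the array stored under "variables" is batched along axis 0.
batch_loss = jax.jit(
    jax.vmap(loss, in_axes=(None, {"variables": 0}), out_axes=0)
)

weights = jnp.array([1.0, 2.0, 3.0])
batch = {"variables": jnp.eye(3)}  # three samples, one per row
losses = batch_loss(weights, batch)  # one loss per sample, shape (3,)
```

Read this way, the dict mirrors the structure of the second argument to `loss`, and `0` marks its single entry as the one that varies across the batch.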