rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

add progress bar for compiled inference loop #70

Closed jeremiecoullon closed 3 years ago

jeremiecoullon commented 3 years ago

I added a progress bar to the compiled version of the sampler. This implements the solution suggested in #69: scan now loops over np.arange() as well as over the keys, so each step knows its iteration number. I've tried it on a simple example and it works. Note that this progress bar is very basic compared to tqdm, so it might be worth adding features.
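
A minimal sketch of the idea, assuming a kernel with the signature kernel(key, state). The scan carries the chain state and loops over (iteration number, key) pairs, so each step can report its own progress. It uses jax.debug.callback, a newer JAX API, in place of the host_callback.id_tap call discussed in this thread. The reporting interval PRINT_EVERY and the function name sample_with_progress are hypothetical, not taken from mcx.

```python
import jax
import jax.numpy as jnp

PRINT_EVERY = 100  # hypothetical reporting interval, not taken from mcx


def sample_with_progress(kernel, initial_state, rng_key, num_samples):
    """Run num_samples steps of kernel inside one lax.scan, printing progress."""
    keys = jax.random.split(rng_key, num_samples)

    def _print_progress(iter_num):
        # Executed on the host with a concrete value of iter_num.
        print(f"Iteration {int(iter_num)}/{num_samples}")

    def one_step(state, xs):
        iter_num, key = xs
        # Only call back to the host every PRINT_EVERY iterations.
        jax.lax.cond(
            iter_num % PRINT_EVERY == 0,
            lambda i: jax.debug.callback(_print_progress, i),
            lambda i: None,
            iter_num,
        )
        new_state = kernel(key, state)
        return new_state, new_state

    # Scan over the iteration numbers *and* the keys.
    xs = (jnp.arange(num_samples), keys)
    _, states = jax.lax.scan(one_step, initial_state, xs)
    return states
```

Wrapping the callback in lax.cond should keep the host round trip off most iterations; calling it unconditionally would also work, but pays the transfer cost at every step.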

What I don't yet understand is why it works fine for any number of chains (one or several). I would have expected each chain to print its own progress bar. Any thoughts as to why?

jeremiecoullon commented 3 years ago

Also, I'm not sure why the tests fail; they seem unrelated to the changes I made. Thoughts?

rlouf commented 3 years ago

Great! Do you have any idea how the performance is affected?

The tests were failing because the version of JAX was pinned while the version of jaxlib wasn't, and the latest jaxlib release was incompatible with the pinned version of JAX. I pinned the jaxlib version as well, so it should be fine now.

This works for all chains because they are currently all moved at the same time: we wait for all chains to finish before moving on to the next step. See this line:

state, info = jax.vmap(kernel, in_axes=(0, 0, 0))(keys, parameters, state)

I will probably have to change this for empirical HMC and NUTS, where chains will take a different number of leapfrog steps at each sampling step. We'll improvise then :)
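
To see why a single progress bar covers every chain, here is a toy version of that vmapped call. The kernel below is a hypothetical Gaussian random walk, not an mcx kernel; what matters is that one call advances all chains by one step together, so a counter that ticks once per call ticks once per step regardless of the number of chains.

```python
import jax
import jax.numpy as jnp


def kernel(key, parameters, state):
    # Hypothetical toy kernel: a Gaussian random-walk step whose scale is
    # `parameters`. It stands in for a real sampler kernel.
    new_state = state + parameters * jax.random.normal(key)
    info = jnp.abs(new_state - state)  # e.g. the size of the move
    return new_state, info


num_chains = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
parameters = jnp.ones(num_chains)
state = jnp.zeros(num_chains)

# A single call advances every chain by exactly one step: the leading axis
# of each argument is the chain index.
state, info = jax.vmap(kernel, in_axes=(0, 0, 0))(keys, parameters, state)
```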

Have you tried updating a tqdm bar manually in the _print_consumer?

jeremiecoullon commented 3 years ago

Performance: I found no difference in performance on the simple example I tried (the linear regression example in the README).

Updating chains: OK, that makes sense!

tqdm: this sort of works but has bugs. I tried the following on the simple example in this gist:

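from jax.experimental import host_callback
from tqdm import tqdm

def _define_tqdm(arg, transform):
    # Runs on the host: create a bar sized to the number of samples.
    global t
    t = tqdm(total=int(arg))

def define_tqdm(arg, result):
    # Return id_tap's output (equal to `result`) so it can be threaded through the computation.
    return host_callback.id_tap(_define_tqdm, arg, result=result)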

And I modified the _print_consumer function to update the tqdm bar:

def _print_consumer(arg, transform):
    # iter_num, num_samples = arg
    t.update()

Then, at the beginning of the sampler, I run define_tqdm. This works perfectly the first time I run it, but when I run it a second time it starts acting weirdly (i.e. it prints the progress bar twice, or prints it on a new line). In particular, it acts slightly differently every time. This might have to do with the fact that t is already defined, or something else.

I'll keep on playing with it!

rlouf commented 3 years ago

> Performance: I found no difference in performance on the simple example I tried (the linear regression example in the README).

OK, that's good! In the end we might keep the uncompiled version, because I noticed it was faster for smaller numbers of samples. I guess that's because of the time it takes to compile the full loop? I wonder if there is anything smarter we could do here.
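
One way to check the compile-time hypothesis is to time both styles of loop separately. This is a toy sketch under stated assumptions: step is a hypothetical random-walk kernel, not the mcx sampler, and python_loop, compiled_loop and timed are illustrative names. python_loop drives a jitted single step from Python, while compiled_loop compiles the whole loop with lax.scan, so its first call includes compilation and later calls with the same num_samples reuse the cached program.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


def step(key, state):
    # Hypothetical toy kernel standing in for one sampler step.
    return state + jax.random.normal(key)


jitted_step = jax.jit(step)


def python_loop(rng_key, state, num_samples):
    # "Non-compiled" loop: Python drives the iterations and only the
    # single step is compiled (once, on the first call).
    for key in jax.random.split(rng_key, num_samples):
        state = jitted_step(key, state)
    return state


@partial(jax.jit, static_argnums=2)
def compiled_loop(rng_key, state, num_samples):
    # Whole loop compiled into one XLA program; recompiled for every new
    # value of num_samples because it fixes the loop length.
    keys = jax.random.split(rng_key, num_samples)
    state, _ = jax.lax.scan(lambda s, k: (step(k, s), None), state, keys)
    return state


def timed(fn, *args):
    # Wait for the asynchronous computation to finish before stopping the clock.
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


key, state = jax.random.PRNGKey(0), jnp.float32(0.0)
for name, fn in [("python loop", python_loop), ("compiled loop", compiled_loop)]:
    _, first = timed(fn, key, state, 1_000)   # includes compilation
    _, second = timed(fn, key, state, 1_000)  # reuses the compiled code
    print(f"{name}: first call {first:.3f}s, second call {second:.3f}s")
```

If the first compiled call is slow while the second one is fast, the gap is compilation time rather than sampling time.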

> Then, at the beginning of the sampler, I run define_tqdm. This works perfectly the first time I run it, but when I run it a second time it starts acting weirdly (i.e. it prints the progress bar twice, or prints it on a new line). In particular, it acts slightly differently every time.

Do you run t.close() at the end? Generally I find that tqdm sometimes has unpredictable behaviors :(

jeremiecoullon commented 3 years ago

> Do you run t.close() at the end? Generally I find that tqdm sometimes has unpredictable behaviors :(

Yeah that was it :)
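
The fix can be sketched as a fresh bar per run that is always closed, even if sampling raises. This is an illustration rather than the code merged in this PR: it creates the bar in Python before the (non-jitted) scan instead of through a define_tqdm host callback, and it uses jax.debug.callback instead of host_callback.id_tap. The function name sample_with_tqdm is hypothetical.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def sample_with_tqdm(kernel, initial_state, rng_key, num_samples):
    # A fresh bar per run instead of a global `t` shared across runs.
    bar = tqdm(total=num_samples)

    def _update(_iter_num):
        bar.update(1)

    def one_step(state, xs):
        iter_num, key = xs
        jax.debug.callback(_update, iter_num)  # tick once per step on the host
        new_state = kernel(key, state)
        return new_state, new_state

    keys = jax.random.split(rng_key, num_samples)
    try:
        _, states = jax.lax.scan(one_step, initial_state, (jnp.arange(num_samples), keys))
        jax.effects_barrier()  # wait until every pending callback has run
    finally:
        # Always release the bar, so the next run starts from a clean line.
        bar.close()
    return states
```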

jeremiecoullon commented 3 years ago

So this seems to work now! Have a look at it.

About the non-compiled version: when I run it, the basic linear regression example always seems faster when compiled (even for small numbers of samples such as 1,000). But maybe keeping the non-compiled version is a good idea anyway. We could make the compiled version the default, though?

Anyway, now that tqdm seems to work, we could make the UI of the compiled version look like the non-compiled one. Maybe there should be a message such as "running compiled sampler" vs "running non-compiled sampler", though?
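
A sketch of the labelling idea, assuming both code paths share the same tqdm bar. make_progress_bar and its compiled flag are hypothetical names; desc is tqdm's standard description argument.

```python
from tqdm import tqdm


def make_progress_bar(num_samples, compiled):
    # Hypothetical helper: both code paths get the same tqdm bar; only the
    # description tells the user which sampler is running.
    mode = "compiled" if compiled else "non-compiled"
    return tqdm(total=num_samples, desc=f"Running {mode} sampler")


bar = make_progress_bar(1_000, compiled=True)
bar.update(1_000)
bar.close()
```

Keeping a single helper means the two samplers cannot drift apart visually; only the label differs.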

rlouf commented 3 years ago

It looks really good now. I think it is ready to merge once the conflict has been resolved?

rlouf commented 3 years ago

Looks ready to merge. If you agree, would you mind squashing your commits beforehand?

jeremiecoullon commented 3 years ago

Ok! Though I've never done it, and after googling this I'm still pretty confused tbh ;)

Where should I squash my commits from, and how do I do that? (I'll keep googling, of course :) )

rlouf commented 3 years ago

You can use git's interactive rebase functionality: https://git-scm.com/book/en/v2/Git-Tools-Rewriting-History. It's a good trick to know. I just saw that you can now do this from the pull request interface, so I'll go ahead and merge.

Good job!

rlouf commented 3 years ago

Addressed one item in #65.