pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Implement AIR model #78

Closed ngoodman closed 6 years ago

ngoodman commented 7 years ago

The attend-infer-repeat (AIR) model is a great example of stochastic recursion in deep probabilistic models. It is going to be one of the anchor examples for the first pyro release.

We need to implement it, implement any extensions that training relies on, and replicate (some) results from the paper.

Getting acceptably low variance for the gradient estimator term from the discrete RV will likely take some amount of Rao-Blackwellization and data-dependent baselines. @martinjankowiak has started or planned work on these.

null-a commented 7 years ago

Getting acceptably low variance for the gradient estimator term from the discrete RV will likely take some amount of Rao-Blackwellization and data-dependent baselines.

Agreed. I expect that it will be necessary to use map_data to weight likelihood-ratio (LR) terms by only those "costs" incurred by the associated data point.
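
As a rough illustration of that weighting (a minimal, hypothetical sketch; these names are illustrative, not Pyro API):

import torch

n = 4
log_q = torch.randn(n, requires_grad=True)  # stand-in for the log q(z_i) terms
cost = torch.randn(n)                       # per-data-point cost, e.g. -elbo_i
baseline = cost.mean()                      # a simple baseline, for illustration

# naive estimator: every choice is weighted by the total mini-batch cost
naive_surrogate = log_q.sum() * cost.sum()

# Rao-Blackwellized estimator: the choice made for data point i is weighted
# only by the cost incurred by data point i (minus a baseline)
rb_surrogate = (log_q * (cost - baseline)).sum()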

I see work on this is underway, which is great! One thing I'm not clear about is how this will interact with sequence models, where the length of the sequence is stochastic, when writing the model to operate on entire mini-batches.

As a simple example, consider this:

import torch
import pyro
from pyro.distributions import Normal, Bernoulli

def local_model(batch):
    n = batch.size(0)

    # first random choice, made for every data point
    x = pyro.sample("x", Normal(torch.zeros(n, 1), torch.ones(n, 1)))

    # flip coins to decide which data points make a second random choice
    p = pyro.sample("p", Bernoulli(0.5 * torch.ones(n, 1)))

    # make the second random choice for only that subset of the mini-batch
    m = int(torch.sum(p))
    y = pyro.sample("y", Normal(torch.zeros(m, 1), torch.ones(m, 1)))

    # combine first and second choices (combine is a placeholder)
    z = combine(x, y, p)

    # observe z under some placeholder observation distribution
    pyro.observe("z", obs_dist, z)

map_data(data, local_model)

The interesting feature of this is that the number of choices made per data point is not fixed. (This is similar but not identical to the case where the inputs are of differing lengths. Related #34, #67.)

Here's the bit I'm unclear about... I'm guessing that the Rao-Blackwellized map_data implementation will assume that a data point has a fixed position within the mini-batch, in order to track the "cost" for each data point across multiple choices? However, that assumption doesn't hold for the example above -- since we only make the second choice for a subset of the data, a particular data point may appear in a different position across choices.

Will models of this form be supported by the planned implementation, or is there a different way of expressing the model to make it fit? (Other than by using batch_size=1 and combining gradient estimates for a mini-batch by hand.)

I suppose a direct approach to solving this would be to pass the information about which subset of data points a choice is being made for to sample, recovering the ability to track data points across choices. (Either as a separate argument, or encoded in the parameters in some way.) This is a bit fiddly though, so hopefully there's something better?
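
To make the "separate argument" variant concrete, a hypothetical sketch (`indices` is not a real Pyro argument; it just records which data points the choice belongs to):

import torch

p = torch.tensor([1., 0., 1., 1.])  # coin flips from the example above
indices = p.nonzero().squeeze(-1)   # data points making the second choice: tensor([0, 2, 3])

# hypothetical API: tell sample which data points this choice is for
# y = pyro.sample("y", Normal(torch.zeros(3, 1), torch.ones(3, 1)), indices=indices)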

Eventually, I think we'd ideally like to write local models for a single data point and have the back end figure out the mini-batch version, but this is tricky of course. (Probably not news, but here is some related work on the variable-length input case: 1, 2.)

ngoodman commented 7 years ago

i discussed some related ideas with @eb8680 yesterday. if i'm understanding correctly this is about the vectorized version of map_data? (i.e. the webppl style version that maps the observation function over each data point in the batch should be ok?)

we'll document this more fully elsewhere, but my current thinking is that the functional map version should be the basic version that defines the correct behavior; vectorization is then an optimization that the implementation can try to do. the thought is that this optimization is achievable by overloading the tensor library to have a batch dim, while marking some tensor ops as "unsafe for vectorization"; if an observation function is trying to be vectorized but hits an unsafe op, it will bail out and fall back on the independent map version. does this make sense?
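
roughly, the control flow i have in mind (a hypothetical sketch; none of these names exist in pyro):

class UnsafeForVectorizationError(Exception):
    pass

def map_data(data, local_model, try_vectorize=True):
    if try_vectorize:
        try:
            # run local_model once over the whole batch, with tensors
            # carrying an extra leading batch dim
            return local_model(data)
        except UnsafeForVectorizationError:
            pass  # hit an op that can't handle the batch dim; fall back
    # reference semantics: map the observation function over data points
    return [local_model(x) for x in data]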

null-a commented 7 years ago

if i'm understanding correctly this is about the vectorized version of map_data?

Yes.

the webppl style version that maps the observation function over each data point in the batch should be ok?

Yes, OK in the sense that we can write the model down and inference will do the right thing. The question I had in mind, but didn't make explicit, was whether the non-mini-batch version would be OK in terms of the performance goals of the project. I'll measure the performance of both implementations at some point.

does this make sense?

Yep, I think I get the general idea.

ngoodman commented 7 years ago

The question I had in mind, but didn't make explicit, was whether the non-mini-batch version would be OK in terms of the performance goals of the project.

right: probably not. but doing it in two steps seems cleaner anyhow.

null-a commented 7 years ago

I'm working on this over on this branch.

I'll measure the performance of both implementations at some point.

An initial test suggests that the vectorized version is about 5 times faster than the webppl-style version. This probably underestimates the difference we'll see in the end, since the particular vectorized implementation I measured was doing a lot of wasted computation that we could likely avoid. (Each sequence in the batch was padded with extra steps so that all sequences had the same length.)
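
To illustrate the padding overhead (a hypothetical sketch, not the actual implementation):

import torch

# every sequence is run for max_steps steps; a mask zeroes out the
# contribution of the padded steps, which is wasted computation
costs = torch.randn(4, 3)             # (batch, max_steps) per-step costs
lengths = torch.tensor([1, 3, 2, 1])  # true number of steps per sequence
steps = torch.arange(3).unsqueeze(0)  # (1, max_steps)
mask = (steps < lengths.unsqueeze(1)).float()
total_cost = (costs * mask).sum()     # padded steps contribute nothing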

eb8680 commented 7 years ago

@null-a now that #61 and #62 and #84 are merged are there any other Pyro features we need for this example?

null-a commented 7 years ago

@eb8680 The main thing is to use the independence info from map_data when building the dependency graph. (Unless that found its way in without me noticing.)

ngoodman commented 7 years ago

@null-a do you have a way to get a quick check of how the model is performing now (with neural baseline but sans mapdata independence)?

null-a commented 7 years ago

do you have a way to get a quick check of how the model is performing now (with neural baseline but sans mapdata independence)?

@ngoodman No. I don't have baselines implemented yet, and once I do I don't know how to check performance other than by running it, which isn't quick.

@martinjankowiak I'm still unclear about how an RNN can be used to output a baseline for each choice in a sequence. Any chance you could provide a rough sketch of how this would look?

martinjankowiak commented 7 years ago

@null-a with the code that's currently in dev, i don't think you can do that. can you provide a pseudo-code snippet that sketches out what you want? then i can see what set of changes would be required. doing this more or less elegantly may require changes in the interface

null-a commented 7 years ago

@martinjankowiak I think the idea would be to have an extra RNN (in addition to the inference net) that runs along the sequence, which is used to output a baseline value at each choice. After the choice, the sampled value would be used to produce a new hidden state. So focussing on a single choice in the sequence for a single data point, and ignoring the inference net, we might have something like:

baseline = some_nn(rnn_hid_state)              # baseline for the upcoming choice
x = pyro.sample('x', dist, baseline=baseline)  # pass the baseline to the choice
new_rnn_hid_state = rnn(rnn_hid_state, embed(x))

I guess as long as we can package the whole baseline net up into a single torch module (so that all of its params are updated) then we can probably make it work with the current interface. That seems like it might be do-able, so I'll make an attempt at some point. Thanks.
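
A rough sketch of that packaging (hypothetical module and sizes, just to show the shape of the idea):

import torch
import torch.nn as nn

class BaselineNet(nn.Module):
    def __init__(self, embed_size, hid_size):
        super().__init__()
        self.rnn = nn.RNNCell(embed_size, hid_size)
        self.embed = nn.Linear(1, embed_size)  # embed the sampled value
        self.predict = nn.Linear(hid_size, 1)  # scalar baseline output

    def baseline(self, h):
        # baseline value for the choice about to be made
        return self.predict(h)

    def step(self, h, x):
        # fold the sampled value into a new hidden state
        return self.rnn(self.embed(x), h)

Registering this one module would expose all of the baseline parameters (RNN, embedding, and output net) for updating.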

null-a commented 7 years ago

Progress update: It looks like my implementation is really slow at present. I estimate it will take around a year to get reasonable inferences out of the guide (optimizing on the CPU), and much longer to run optimization for as long as they did for the paper. (Assuming I'm interpreting the results in the paper correctly.)

So, I'll need a bunch of tricks to speed this up, and I guess that one of those will end up being the use of vectorized map_data. This will require Rao-Blackwellization/baselines for vectorized map_data, and perhaps parts of #34.

ngoodman commented 7 years ago

hmm.... we didn't have vectorization in webppl, so how did we get acceptable performance there (or did we never get that far)?

null-a commented 7 years ago

how did we get acceptable performance there (or did we never get that far)?

The latter.

karalets commented 7 years ago

Hi guys,

Regarding baselines. I wanted to throw in a chunk from the paper's appendix that maybe was missed here.

"I.5 Supervised learning For the baselines trained in a supervised manner we use the ground truth scene variables [z_1:N pres, z:N where, z1:N what] that underly the training scene images as labels and train a network of the same form as the inference network to maximize the conditional log likelihood of the ground truth scene variables given the image."

As such, they basically learn baselines in a supervised way (apparently).

I shot them an email to follow up, but wanted to highlight this here. Linking to #126.

null-a commented 7 years ago

I wanted to throw in a chunk from the paper-appendix that maybe was missed here.

My understanding is that this is used on the 3D scene example and not the multi-mnist example I'm working on. (I'm mentioning this only to point out that I don't think this stands in the way of me reproducing the result I'm shooting for, and not to take anything away from the idea, which is interesting.)

null-a commented 7 years ago

Progress update:

It looks like my implementation is really slow at present.

I've now re-written the model in vectorized style. In order to optimize it, I've cobbled together an implementation of kl_qp that supports vectorized Rao-Blackwellization/baselines. This kl_qp is an ugly hack that works for this model but isn't very general. It has allowed me to get results out of this model, but it's nothing like the fully general thing that we need to implement eventually.

I've also switched to running this on a GPU, which together with vectorization makes things run a couple of orders of magnitude faster, making it usable.

Results so far: My goal is to replicate their first result on multi-mnist, and optimization of the pyro implementation seems to be working almost as hoped. I'm seeing progress on the ELBO similar to that reported in the paper, the inference net is successfully picking out digits in the image, and reconstructions look reasonable.

The main snag is that rather than always avoiding use of the final time step (which is never necessary to explain the input) the guide is instead wasting the first time step, by using it to explain nothing. For example:

Four input images: [image: inputs]

Reconstructions, with visualization of (some of the) latents: [image: final-recon] (Step one in red, two in green, three in blue.)

I don't yet understand why this is happening, but I'm working on it. (Any thoughts on this are welcome, of course!)

ngoodman commented 7 years ago

I've also switched to running this on a GPU, which together with vectorization makes things run a couple of orders of magnitude faster, making it usable.

woohoo!

The main snag is that rather than always avoiding use of the final time step (which is never necessary to explain the input) the guide is instead wasting the first time step, by using it to explain nothing.

fascinating. if you only give it two timesteps, then it uses the first appropriately? and you're sure you are counting in the right direction? ;)

This kl_qp is an ugly hack that works for this model but isn't very general. It has allowed me to get results out of this model, but it's nothing like the fully general thing that we need to implement eventually.

we (collectively) should make a plan for getting a clean version worked out and into dev, since i think we'll need it for release.

null-a commented 7 years ago

if you only give it two timesteps, then it uses the first appropriately?

Yeah, it appears so:

[image: img-2-steps-recon]

(though here the last step is still used unnecessarily when there's one digit.)

ngoodman commented 7 years ago

interesting! could it be as simple as decreasing the prior probability of recursion, to better encourage not using extra steps? (this still doesn't explain why the earlier result was punting on the first step, which is very odd.)

null-a commented 7 years ago

could it be as simple as decreasing the prior probability of recursion, to better encourage not using extra steps?

Yeah, I already made one change in that direction for the reason you suggest, but I could go further.

Related to this, yesterday I learned that for the paper they "annealed the success probability from a value close to 1 to either 1e-5 or 1e-10 depending on the dataset over the course of 100k training iterations". (I'd rather not have to do that though.)

ngoodman commented 7 years ago

useful blog post -- nice find!

annealing the success probability doesn't seem crazy (if a bit hacky). it would be nice if this were straightforward to do in pyro (it's the kind of tinkering with learning that we want to make accessible). i think it might be as simple as making success prob an arg to model, which then becomes an arg to kl.step, and changing it as we like over learning? if it's more complex than that, there's no need to implement now, but it'd be nice to think it through sometime.
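
something like this, maybe (a hypothetical sketch; the schedule is illustrative, and kl.step stands in for whichever step function we use):

import math

def annealed_success_prob(step, num_steps=100000, initial=0.99, final=1e-5):
    # interpolate from `initial` down to `final` in log space
    t = min(step / float(num_steps), 1.0)
    return math.exp((1 - t) * math.log(initial) + t * math.log(final))

# inside the training loop, the model would receive this as an argument:
#   loss = kl.step(batch, success_prob=annealed_success_prob(i))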

null-a commented 7 years ago

i think it might be as simple as making success prob an arg to model, which then becomes an arg to kl.step, and changing it as we like over learning?

Yeah, I think it's as simple as that.

eb8680 commented 6 years ago

Closed by #259