pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

DMM / DVBF examples #77

Closed: ngoodman closed this issue 6 years ago

ngoodman commented 7 years ago

Deep time-series models such as DMM and DVBF will be anchor models for the first release. We need to implement them in pyro, implement any training tricks needed, and replicate a few results.

@karalets has some pytorch code for DMM and @null-a has a webppl implementation, so it should be straightforward to implement.

null-a commented 7 years ago

FWIW, my webppl implementation is here.

martinjankowiak commented 7 years ago

first (semi-)working version of the deep markov model on the polyphonic music data is here:

https://github.com/uber/pyro/blob/martin-dev-dmm/examples/dmm.py

the goal is to reproduce the 'DKS' (deep kalman smoother) numbers for the JSB chorales dataset in 'Structured Inference Networks for Nonlinear State Space Models.' specifically to get something close to the NLL of 6.605 reported in the paper.

besides important details like making use of train/test/validation splits properly and fine-tuning the optimization, the major functionality lacking in the above code is mini-batch support (although it's worth noting that the experiments reported in the dmm paper were apparently done training one sequence at a time). at this point there are at least four possibilities for further development:

  1. limit oneself to batch_size = 1 [training would probably take 1-2 days]
  2. modify the JSB dataset to have fixed sequence lengths by lopping off some of the sequences (this has the obvious downside that it precludes a direct comparison to the results in the paper)
  3. figure out how to do variable length sequences in pyro in some hacky way
  4. figure out how to do variable length sequences in pyro in some more or less elegant way (a rough padding + masking sketch follows this list)
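
for reference, a minimal sketch of the data-side piece of options 3/4: pad each JSB sequence to the batch max length and keep a 0/1 mask recording which timesteps are real (the helper name and the 88-dim piano roll here are just assumptions for illustration):

import torch

def pad_and_mask(sequences):
    # sequences: list of float tensors, each of shape (T_i, 88) for the piano roll
    lengths = [seq.size(0) for seq in sequences]
    T_max = max(lengths)
    padded = torch.zeros(len(sequences), T_max, 88)
    mask = torch.zeros(len(sequences), T_max)
    for i, seq in enumerate(sequences):
        padded[i, :lengths[i]] = seq
        mask[i, :lengths[i]] = 1.0
    return padded, mask, torch.tensor(lengths)

the mask is what would eventually be used to zero out log probabilities at the padded timesteps.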

thoughts? what are the minimum desiderata for release?

martinjankowiak commented 7 years ago

for completeness here are some of the current deltas (or rather possible deltas, since it's not always clear what was done in the paper):

  1. no dropout in rnn
  2. weight initialization probably different in various places
  3. no kl annealing
  4. no analytic kls (a small diagonal-gaussian kl sketch follows this list)
  5. no additional regularization (l2 etc.) [?]
  6. like in paulH's code, we parameterize sigma and not sigma^2
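
re item 4, the analytic kl that would replace the sampled kl term for the gaussian latents is the standard diagonal-gaussian expression; a small sketch:

import torch

def kl_diag_normals(mu_q, sigma_q, mu_p, sigma_p):
    # KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ), summed over dimensions
    var_ratio = (sigma_q / sigma_p).pow(2)
    mean_term = ((mu_q - mu_p) / sigma_p).pow(2)
    return 0.5 * torch.sum(var_ratio + mean_term - 1.0 - var_ratio.log())
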
karalets commented 7 years ago

I will add analytic kldivs next to klqp as an option, as outlined in the issue; that should take care of that one item.

Also, irregular time series have their own open issue. I tackled that with some masking and can share that implementation when I get back, but I think we should handle it with packed sequences from pytorch.

However, I think this is a separate issue and should be kept separate.

Also: I think doing this on minibatches is important; my code for DKFs (which you have) does that.


ngoodman commented 7 years ago

@martinjankowiak what NLL are you getting as of now?

re: minibatches, we can presumably do minibatches with non-vectorized map_data? so the problem is that we want vectorization for speed? i would say the right solution will be the one closest to pytorch (even if we don't love their solution), which means futzing around with packed sequences. perhaps we should discuss as a group.

btw do we have a reference implementation (ie from the paper)? if so, how long does it take to train? are we much slower?

@null-a in your experience with the DMM in webppl, which of the deltas that @martinjankowiak lists above are most likely to be important?

martinjankowiak commented 7 years ago

@ngoodman after 200 epochs of training i get a (minus) elbo of about 9.4 on the training set. this is to be compared to the elbo of 7.0 they report on the test set after 2000 epochs. for comparison a vanilla RNN gets an NLL of 8.7. unfortunately after 200 epochs (about 6 hours!) i got nans. their implementation is in theano. no idea how slow or fast it is.

re: vectorization, i think some combination of packed sequences on the pytorch end along with something like a log_pdf_mask argument that's passed to sample and observe would be sufficient. thoughts?
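
for the pytorch end, the packed-sequence piece could look roughly like this (a sketch using today's pytorch API; at the time sequences also had to be pre-sorted by length, and the sizes below are placeholders):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.GRU(input_size=88, hidden_size=200, batch_first=True)

padded = torch.zeros(8, 50, 88)          # (batch, T_max, features)
lengths = torch.randint(1, 51, (8,))     # true length of each sequence

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
packed_out, _ = rnn(packed)
# back to a padded tensor; outputs past each sequence's true length are zero
rnn_out, _ = pad_packed_sequence(packed_out, batch_first=True)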

karalets commented 7 years ago

Masking: When we did masking with Cem we just masked out vectors before passing them to sample or observe and it was fine if we iterated over single samples with irregular sequence lengths.

Testing: You can now also report scores on test data with the new klqp changes; that is effectively what they are for.

Pytorch dkf: I feel 2 nats off is significant. I am in transit now, but either you or I could run my simple pytorch dkf code to check what scores it would get on exactly your data. If it has any advantage, that is a good sign because it shows that improvements to the Pyro code could be easy to get (analytic kldivs, maybe structure?).

Hyperparameters: I assume you use precisely their settings throughout? You could try a slower learning rate to get more iterations; there may be many small details in the pytorch Adam settings or layer initialization that map to slight changes in the setup.


karalets commented 7 years ago

But again: there is an issue for dynamic lengths; I strongly advise not to conflate it with the dkf issue.


martinjankowiak commented 7 years ago

these are initial numbers. no fine tuning was done. it didn't even converge

null-a commented 7 years ago

in your experience with the DMM in webppl, which of the delta's that @martinjankowiak lists above are most likely to be important?

@ngoodman Most of those deltas apply to the webppl implementation too. The only differences between pyro and webppl appear to be weight decay and some weight init. (I didn't do ortho. init for the RNN either.) Not sure either of those explain the difference in behaviour?

ngoodman commented 7 years ago

sounds like perhaps just running it longer would get the NLL down....

null-a commented 7 years ago

FWIW, I ran the webppl implementation for 8 days (!); here's what progress on the elbo looked like:

One more possible delta: I ran this with gradient clipping, and IIRC the implementation used for the paper did something similar. (ETA: Again, not sure whether this was important.)

martinjankowiak commented 7 years ago

pushed a version with vectorized mini-batches. this required support for masking log_pdfs. see PR #119. how much do people hate such an approach? see also issue #34 for dynamic lengths.

ngoodman commented 7 years ago

@martinjankowiak could you say what you mean by masking log_pdfs? (give me a sketch so i don't have to go through code details?)

martinjankowiak commented 7 years ago

on the user end:

z = pyro.sample("z", ...., log_pdf_mask=my_mask)

somewhere in the internals:

if 'log_pdf_mask' in kwargs:
    # zero out contributions from masked (e.g. padded) entries before reducing
    return torch.sum(kwargs['log_pdf_mask'] * scores)
return torch.sum(scores)

currently the latter logic is replicated in the two distributions i modified, but it would be moved elsewhere if we wanted this in dev.

ngoodman commented 7 years ago

so what we'd really like is to be able to tell pyro.sample to not sample some rows of the output tensor?

martinjankowiak commented 7 years ago

yeah, although it needn't be entire rows or columns. and same for pyro.observe. and ideally it should be fast (so probably vectorized)

ngoodman commented 7 years ago

how do you see this interacting with pytorch packed sequences? presumably we'd like to be able to do a map_data over a packed sequence and have the Right Thing happen? maybe this suggests that pyro.observe should simply be packed sequence aware? (ie it's ok to sample extra choices but not ok to observe non-existent values....)

ps. this discussion probably belongs over in #34 -- sorry.

martinjankowiak commented 7 years ago

something like that sounds about right, but it might be too restrictive. for example, in the current case the rnn which informs the guide scans from right to left. so the packed sequence runs t = T, T-1, ..., 2, 1. however the sampling in the latent space runs in the opposite (i.e. forward) direction. so i guess with the proposed interface i'd need two different packed sequences?
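
one possible workaround: keep a single padded batch and just hand the guide rnn a per-sequence reversed copy, so an ordinary left-to-right rnn effectively scans right-to-left (a sketch using today's tensor ops, assuming a padded batch and a lengths tensor as above):

import torch

def reverse_padded(padded, lengths):
    # reverse each sequence within its own length, leaving the padding in place
    reversed_padded = padded.clone()
    for i, T in enumerate(lengths.tolist()):
        reversed_padded[i, :T] = padded[i, :T].flip(0)
    return reversed_padded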

martinjankowiak commented 7 years ago

having fixed an indexing bug and letting things run longer, i'm now down to test elbo's in the range of ~7.8 (still probably has a fair bit to go before it actually converges though). the target is ~7.0. kl annealing definitely seems to help, although i haven't made a careful comparison yet.
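
for concreteness, the kl annealing is just a scalar weight on the kl term that ramps up over training; something like the following (a sketch, the schedule and constants here are arbitrary):

def kl_annealing_factor(epoch, warmup_epochs=100, minimum=0.1):
    # linearly ramp the weight on the kl term from `minimum` up to 1.0
    if epoch >= warmup_epochs:
        return 1.0
    return minimum + (1.0 - minimum) * epoch / float(warmup_epochs)

# the per-mini-batch objective is then weighted as
# loss = nll_term + kl_annealing_factor(epoch) * kl_term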

besides fine-tuning hyperparameters more, the two lowest hanging fruit on offer to (possibly) improve things further are probably:

(i) gradient clipping. this isn't mentioned in the paper but seems to be in their code
(ii) weight regularization for the neural networks

i intend to try both. however, regarding (ii), what do people suppose is the cleanest approach? if i put gaussian priors on weights in the neural networks, the weights are stochastic and the resulting regularization includes L2 terms but also includes other terms (and introduces more variational parameters etc.). having fixed weights with just weight decay would require some strange mixed VI/MLE approach. probably the simplest thing would be some more or less hacky weight regularization such that the L2 weight terms aren't included in any sort of NLL estimates (which is presumably what they do?). is that also how webppl treats default regularization in nn's? as something outside the scope of mr bayes?

eb8680 commented 7 years ago

@martinjankowiak re: (ii), we can implement it on top of random_module and lift from #121 by automatically generating independent Gaussian priors and delta guides. We can hide the L2 terms from the NLL evaluation code with poutine.block.
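
Roughly, the idea would look something like this (a sketch assuming the random_module interface from #121 lands as proposed; the network, names, and exact signatures here are placeholders):

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

emitter = nn.Linear(100, 88)   # stand-in for one of the DMM networks
prior_scale = 10.0             # 1 / prior_scale**2 plays the role of the weight decay coefficient

def lifted_emitter_model():
    # independent Gaussian priors over all weights and biases
    priors = {name: dist.Normal(torch.zeros_like(p), prior_scale * torch.ones_like(p)).to_event(p.dim())
              for name, p in emitter.named_parameters()}
    return pyro.random_module("emitter", emitter, priors)()

def lifted_emitter_guide():
    # Delta guides: one learnable point estimate per weight, so the only weight term
    # left in the ELBO is the log prior, i.e. an L2 penalty on the point estimates
    deltas = {name: dist.Delta(pyro.param("emitter_{}".format(name), p.detach().clone())).to_event(p.dim())
              for name, p in emitter.named_parameters()}
    return pyro.random_module("emitter", emitter, deltas)()

The weight sample sites could then be hidden from any NLL bookkeeping with poutine.block, as mentioned above.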

martinjankowiak commented 7 years ago

what happens to the entropy term for the guide? an arbitrary constant that doesn't impact optimization?

eb8680 commented 7 years ago

The entropy term is zero for delta guides
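
Concretely (just the bookkeeping): with a point-mass guide $q(w) = \delta(w - \hat{w})$ and prior $p(w) = \mathcal{N}(0, \sigma^2 I)$, the weight contribution to the ELBO is

$\mathbb{E}_{q}[\log p(w)] - \mathbb{E}_{q}[\log q(w)] = \log \mathcal{N}(\hat{w} \mid 0, \sigma^2 I) - 0 = -\tfrac{1}{2\sigma^2}\lVert \hat{w} \rVert^2 + \text{const},$

so optimizing it over $\hat{w}$ is exactly L2 weight decay / MAP (Delta's log_prob at its own value is 0, which is why the entropy term drops out).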

martinjankowiak commented 7 years ago

i see. well that would work. when will that be ready to go?

eb8680 commented 7 years ago

Whenever @jpchen says #121 is ready to merge

martinjankowiak commented 7 years ago

status update