blei-lab / edward

A probabilistic programming language in TensorFlow. Deep generative models, variational inference.
http://edwardlib.org
Other

Fix errors in ReplicaExchangeMC #871

Closed: YoshikawaMasashi closed this 6 years ago.

YoshikawaMasashi commented 6 years ago

Hi.

I implemented ReplicaExchangeMC in #865, but I have since found some errors and fixed them.

multiple latent variables

Code that uses multiple latent variables does not work. For example, the following code fails:

inference = ed.ReplicaExchangeMC({w: q_w, b: q_b}, {w: proposal_w, b: proposal_b},
                                 data={X: X_train, y: y_train})
inference.run()

The cause is that tf.case cannot take a callable that returns a dict. The error occurs in the following code, where replica_sample[i] is a dict. (As an exception, a dict that holds only one variable seems to work.)

sample_i = tf.case({tf.equal(new_replica_idx[candi], i):
                    _stateful_lambda(replica_sample[i])
                    for i in range(self.n_replica)},
                   default=lambda: replica_sample[0], exclusive=True)

I fixed this error by calculating the rates at each replica first, so tf.case is no longer used.
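As a rough illustration of why computing per-replica quantities up front avoids the problem, here is a framework-free sketch in plain Python (not TensorFlow, and not Edward's actual code). It assumes each replica's state is a dict of latent variables, computes every replica's log density once, and then applies the standard replica-exchange test, accepting a swap between replicas i and j with probability min(1, exp((beta_i - beta_j) * (log p(x_j) - log p(x_i)))). The names `replica_exchange_step`, `log_prob`, and `betas` are hypothetical.

```python
import math
import random


def replica_exchange_step(states, log_prob, betas, rng):
    """Propose swaps between adjacent replicas (illustrative sketch only).

    states:   list of dicts, one per replica, mapping each latent-variable
              name to its current value, so several variables are allowed.
    log_prob: hypothetical function mapping such a dict to the untempered
              log density.
    betas:    inverse temperatures, one per replica.
    rng:      a random.Random instance.
    Returns the swapped states and the permutation that was applied.
    """
    n = len(states)
    # Compute every replica's log density up front, once.
    log_ps = [log_prob(s) for s in states]
    # order[k] is the index of the original state now held by replica k.
    order = list(range(n))
    for i in range(n - 1):
        a, b = order[i], order[i + 1]
        # Log Metropolis ratio for exchanging the states of replicas i, i+1.
        log_ratio = (betas[i] - betas[i + 1]) * (log_ps[b] - log_ps[a])
        # min(0, .) keeps exp() from overflowing; the accept prob is <= 1.
        if rng.random() < math.exp(min(0.0, log_ratio)):
            order[i], order[i + 1] = b, a
    # Rebuild the dict-valued states by plain indexing: no branch ever
    # has to return a dict.
    return [dict(states[k]) for k in order], order


# Usage: two replicas, two latent variables (w, b), standard normal target.
states = [{"w": 1.0, "b": 1.0}, {"w": 0.0, "b": 0.0}]
new_states, order = replica_exchange_step(
    states, lambda s: -(s["w"] ** 2 + s["b"] ** 2) / 2, [1.0, 0.5],
    random.Random(0))
# The colder replica gains the higher-density state, so the swap is certain:
# order == [1, 0]
```

Because a swap only permutes replica indices, a state with any number of latent variables moves as a single unit.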

list for `latent_vars`

When I pass a list as latent_vars, an error occurs:

inference = ed.ReplicaExchangeMC(latent_vars=[x],
                                 proposal_vars={x: proposal_x})
inference.run()

The call super(ReplicaExchangeMC, self).__init__(latent_vars, data) converts latent_vars from a list to a dict, but ReplicaExchangeMC already needs latent_vars as a dict before that call is made. I fixed this as well.
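The ordering problem can be illustrated with two hypothetical plain-Python classes (not Edward's actual `Inference` hierarchy). The base class converts a list to a dict, but the subclass runs dict-only logic earlier, so it normalizes the argument itself before calling `super().__init__`. `make_default` is a hypothetical callback standing in for whatever default the real base class builds for each listed variable.

```python
class BaseInference(object):
    """Hypothetical base class that accepts latent_vars as a list or a dict."""

    def __init__(self, latent_vars, make_default):
        if isinstance(latent_vars, list):
            latent_vars = {z: make_default(z) for z in latent_vars}
        self.latent_vars = latent_vars


class ReplicaExchangeSketch(BaseInference):
    """Hypothetical subclass that needs latent_vars as a dict in __init__."""

    def __init__(self, latent_vars, proposal_vars, make_default):
        # Normalize a list to a dict *before* any dict-only logic runs,
        # instead of waiting for the base class to do it.
        if isinstance(latent_vars, list):
            latent_vars = {z: make_default(z) for z in latent_vars}
        # Dict-only logic: .items() would fail with AttributeError on a list.
        self.pairs = {z: (qz, proposal_vars[z])
                      for z, qz in latent_vars.items()}
        # The base class now receives a dict and leaves it unchanged.
        super(ReplicaExchangeSketch, self).__init__(latent_vars, make_default)
```

Passing either `["x"]` or `{"x": q_x}` then yields the same dict-shaped state inside the subclass.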

Thanks

YoshikawaMasashi commented 6 years ago

I added a unit test and refactored the code. A unit test that passes a list of tf.float64 latent variables as latent_vars raised an error, so I fixed that bug too.

Should I also add similar unit tests for the other MC methods (MetropolisHastings and so on)?

dustinvtran commented 6 years ago

Great! Merging.

Should I also add similar unit tests for the other MC methods (MetropolisHastings and so on)?

Absolutely.