apologies for the non-response! i was busy hacking away to submit an iclr paper for edward in time. ;)
thanks for looking into this. so from what i understand, this is to enable mini-batch training for inference on the model wrappers, with proper scaling of the log-densities. your proposals make sense. from your fork, it looks like you're properly subsetting the local latent variables, but you're not scaling the log-likelihood?
it's funny that you brought this up though, because part of the paper mentions how to do proper mini-batch training. (you can see it in section 4.5 here)
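for concreteness, here's a minimal sketch of the kind of scaling i mean (toy names, not edward's internals): with a minibatch of M points out of N, the summed log-likelihood over the batch is reweighted by N / M so its expectation matches the full-data term.
import tensorflow as tf
from edward.models import Normal

N = 5000  # full dataset size
M = 100   # minibatch size

x_ph = tf.placeholder(tf.float32, [M])
x_like = Normal(mu=tf.zeros(M), sigma=tf.ones(M))

# sum of per-point log-likelihoods over the minibatch, rescaled by
# N / M to form an unbiased estimate of the full-data log-likelihood
log_lik = float(N) / M * tf.reduce_sum(x_like.log_prob(x_ph))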
Hi Dustin! Thanks for your reply!
We've been thinking about and testing what you said. We believe we certainly have to scale the log-likelihood, but we should also scale and index the q_log_prob[s] term that corresponds to cn (the local latent variables) of that batch in build_score_loss_and_gradients. Don't you think so?
if z == 'cn':
  # index the log-density terms of the local latent variables that
  # belong to the current batch, then rescale by N / n_minibatch
  probs_cn = qz.log_prob(tf.stop_gradient(z_sample[z]))
  probs_cn_batch = probs_cn[inference.model_wrapper.i_batch * inference.n_minibatch:
                            (inference.model_wrapper.i_batch + 1) * inference.n_minibatch]
  q_log_prob[s] += tf.reduce_sum(probs_cn_batch) * inference.model_wrapper.N / inference.n_minibatch
else:
  q_log_prob[s] += tf.reduce_sum(qz.log_prob(tf.stop_gradient(z_sample[z])))
Another thing: when sess.run is called, every iteration applies gradients to the whole vector cn. We believe we should only optimise the entries of cn that correspond to the current batch. Do you agree? How could we do this?
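For illustration, one direction we have considered (just a sketch with made-up names, not Edward's API): taking gradients through tf.gather yields IndexedSlices, so a standard optimizer only updates the gathered rows.
import tensorflow as tf

N, K, M = 5000, 2, 100
cn_logits = tf.Variable(tf.zeros([N, K]))  # one row per data point
idx_ph = tf.placeholder(tf.int32, [M])     # indices of the current batch

# gradients w.r.t. cn_logits flow only through the gathered rows,
# so the optimizer applies a sparse update touching just this batch
batch_logits = tf.gather(cn_logits, idx_ph)
loss = tf.reduce_sum(tf.square(batch_logits))  # stand-in loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[cn_logits])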
Yeah I think this is where the use case of model wrappers sort of breaks down. It's clear that we need structure to deal with scaling and indexing, not only with respect to the likelihood but also the variational distribution's density. Handling this to generalize to many settings is very difficult.
I recommend looking at #327. It tries to solve these problems for the native Edward language.
Hi Dustin!
We are trying to implement the Gaussian mixture model with subsampling, as shown in the PCA subsampling example. This is the resulting code:
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import edward as ed
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import tensorflow as tf
from edward.models import Dirichlet, Normal, InverseGamma, Categorical
from edward.util import get_dims, log_sum_exp, get_session
plt.style.use('ggplot')
N = 5000
M = 100
K = 2
D = 2
N_ITER = 200
N_SAMPLES = 10
def build_toy_dataset(N):
  pi = np.array([0.5, 0.5])
  mus = [[2, 2], [-2, -2]]
  stds = [[0.1, 0.1], [0.1, 0.1]]
  x = np.zeros((N, 2), dtype=np.float32)
  for n in range(N):
    k = np.argmax(np.random.multinomial(1, pi))
    x[n, :] = np.random.multivariate_normal(mus[k], np.diag(stds[k]))
  return x

def next_batch(M):
  idx_batch = np.random.choice(N, M)
  return x_data[idx_batch, :], idx_batch
ed.set_seed(42)
x_data = build_toy_dataset(N)
plt.scatter(x_data[:, 0], x_data[:, 1])
plt.axis([-3, 3, -3, 3])
plt.title('Simulated dataset')
plt.show()
# Probabilistic model definition
pi = Dirichlet(alpha=tf.constant([1.0]*K))
mu = Normal(mu=tf.zeros([K, D]), sigma=tf.ones([K, D]))
sigma = InverseGamma(alpha=tf.ones([K, D]), beta=tf.ones([K, D]))
c = Categorical(logits=ed.tile(ed.logit(pi), [M, 1]))
x = Normal(mu=tf.gather(mu, c), sigma=tf.gather(sigma, c))
# Variational model definition
var_qpi_alpha = tf.Variable(tf.random_normal([K]))
var_qmu_sigma = tf.Variable(tf.random_normal([K, D]))
var_qsigma_alpha = tf.Variable(tf.random_normal([K, D]))
var_qsigma_beta = tf.Variable(tf.random_normal([K, D]))
qpi_alpha = tf.nn.softplus(var_qpi_alpha)
qc_logits = tf.Variable(tf.zeros([N, K]))
qmu_mu = tf.Variable(tf.random_normal([K, D]))
qmu_sigma = tf.nn.softplus(var_qmu_sigma)
qsigma_alpha = tf.nn.softplus(var_qsigma_alpha)
qsigma_beta = tf.nn.softplus(var_qsigma_beta)
idx_ph = tf.placeholder(tf.int32, M)
qpi = Dirichlet(alpha=qpi_alpha)
qc = Categorical(logits=tf.gather(qc_logits, idx_ph))
qmu = Normal(mu=qmu_mu, sigma=qmu_sigma)
qsigma = InverseGamma(alpha=qsigma_alpha, beta=qsigma_beta)
x_ph = tf.placeholder(tf.float32, [M, D])
# Inference process
inference_global = ed.KLqp({pi: qpi, mu: qmu, sigma: qsigma}, data={x: x_ph, c: qc})
inference_local = ed.KLqp({c: qc}, data={x: x_ph, pi: qpi, mu: qmu, sigma: qsigma})
inference_global.initialize(scale={x: float(N) / M, c: float(N) / M}, var_list=[var_qpi_alpha, qmu_mu, var_qmu_sigma, var_qsigma_alpha, var_qsigma_beta], n_samples=N_SAMPLES)
inference_local.initialize(scale={x: float(N) / M, c: float(N) / M}, var_list=[qc_logits], n_samples=N_SAMPLES)
sess = ed.get_session()
init = tf.initialize_all_variables()
init.run()
for t in range(N_ITER):
  x_batch, idx_batch = next_batch(M)
  for _ in range(5):
    inference_local.update(feed_dict={x_ph: x_batch, idx_ph: idx_batch})
  info_dict = inference_global.update(feed_dict={x_ph: x_batch, idx_ph: idx_batch})
  inference_global.print_progress(info_dict)
# Criticism
clusters = qc.sample().eval()
plt.scatter(x_data[:, 0], x_data[:, 1], c=clusters, cmap=cm.bwr)
plt.axis([-3, 3, -3, 3])
plt.title('Predicted cluster assignments')
plt.show()
I do not know if we are overlooking something, but this code gives us the following error:
Traceback (most recent call last):
File "new.py", line 92, in <module>
inference_local.update(feed_dict={x_ph: x_batch, idx_ph: idx_batch})
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/inferences/variational_inference.py", line 147, in update
_, t, loss = sess.run([self.train, self.increment_t, self.loss], feed_dict)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 915, in _run
feed_dict_string, options, run_metadata)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _do_run
target_list, options, run_metadata)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 985, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: assertion failed: [] [Condition x < y did not hold element-wise: x = ] [Dirichlet/sample/Reshape:0] [1 2.4928351e-12] [y = ] [assert_less/y:0] [1]
[[Node: inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_0, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_1, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_2, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch_1, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_4, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_5, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch_2)]]
Caused by op u'inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert', defined at:
File "new.py", line 83, in <module>
inference_local.initialize(scale={x: float(N) / M, c: float(N) / M}, var_list=[qc_logits], n_samples=N_SAMPLES)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 63, in initialize
return super(KLqp, self).initialize(*args, **kwargs)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/inferences/variational_inference.py", line 104, in initialize
self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 119, in build_loss_and_gradients
return build_score_loss_and_gradients(self, var_list)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 537, in build_score_loss_and_gradients
z_copy = copy(z, dict_swap, scope=scope)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 160, in copy
value = copy(value, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 206, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 206, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 206, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 206, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 197, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 206, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 174, in copy
new_op = copy(op, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 197, in copy
elem = copy(x, dict_swap, scope, True, copy_q)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 234, in copy
op_def)
File "/home/alberto/.virtualenvs/analytics/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1298, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): assertion failed: [] [Condition x < y did not hold element-wise: x = ] [Dirichlet/sample/Reshape:0] [1 2.4928351e-12] [y = ] [assert_less/y:0] [1]
[[Node: inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_0, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_1, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_2, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch_1, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_4, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/data_5, inference_140330950245520/2/assert_less/Assert/AssertGuard/Assert/Switch_2)]]
This error has already come up several times when we have tested different implementations of this model. What do you think it could be?
it's definitely worth investigating. i've been having issues with getting the model to work as well, even without data subsampling. most likely BBVI (at least without control variates) is not able to capture the mixture assignments well and requires working on the collapsed model.
one easy way to work around the above is to fix the latent mixture probabilities and only do inference over the cluster means, standard deviations, and mixture assignments. (although the inference still gives poor results)
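e.g., something along these lines, reusing names from your script (a sketch; the uniform pi is my assumption):
import tensorflow as tf
from edward.models import Categorical

K, M = 2, 100
# fix the mixture probabilities to a uniform constant instead of a
# Dirichlet latent variable, so ed.logit and its assertion never run
pi = tf.constant([1.0 / K] * K)
c = Categorical(logits=tf.tile(tf.reshape(tf.log(pi), [1, K]), [M, 1]))
# mu, sigma, and c stay latent; pi and qpi drop out of the KLqp
# latent-variable dictionaries entirely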
I have come across the same assert in my own attempt to port the GMM to Edward's native language. I have tracked the assert to the assert_less(x, 1) in ed.logit(x). I'll provide further updates if I figure out why samples from the Dirichlet proposal distribution are > 1.
The problem is not that the Dirichlet sample is returning samples > 1, but that it is returning samples that are exactly 1, whereas ed.logit is only defined over the open interval (0, 1). I was hoping to use Categorical's p keyword argument instead of logits to avoid calling ed.logit at all, but that blows up with an unexpected keyword argument exception on my system (I'm not sure if it's because of how Edward wraps tf's Categorical class or just because this parameter was introduced in a newer version of TensorFlow than I have installed).
As an alternative, I tried clamping the input to ed.logit to the range (0 + epsilon, 1 - epsilon), and that helps, but after some iterations the Dirichlet samples eventually become NaNs. I can't explain that yet.
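In case it helps, the clamping I tried looks roughly like this (a sketch; eps is an arbitrary small constant, and the manual log-ratio stands in for ed.logit):
import tensorflow as tf
from edward.models import Dirichlet

K = 2
eps = 1e-6  # arbitrary small constant, for illustration only
pi = Dirichlet(alpha=tf.ones(K))
# clip the sample away from {0, 1} so the logit stays finite
pi_safe = tf.clip_by_value(pi, eps, 1.0 - eps)
logits = tf.log(pi_safe) - tf.log(1.0 - pi_safe)  # same mapping as ed.logit, on the clipped value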
i've also noticed that this is the primary error. once the dirichlet returns a degenerate or nearly degenerate sample, i.e., a vector of zeros where one element is close to 1, the random variable's log prob returns NaN. this immediately breaks the algorithm, causing more NaNs downstream.
import edward as ed
import tensorflow as tf
from edward.models import Dirichlet

ed.get_session()
x = Dirichlet(alpha=tf.ones(5))
x.log_prob(tf.constant([0.001, 0.001, 0.001, 0.001, 0.996])).eval()
## nan
Hi there! Thanks for your work on Edward!
We are trying to implement a Gaussian mixture model without marginalizing out the local variables (cn), using the MFVI approximation. For this purpose we have modified some Edward functionality, which you can find in this fork. The only modifications are in the files inference.py and variational_inference.py (you can see all the changes by searching for the comment "ADDED").
Basically, we have modified the Edward code to pass the batch index to our log_prob function in order to index the corresponding cn tensor.
Modification 1: Main loop of the variational inference process.
Modification 2: To ensure the arrival order of the batches, we modify the call to tf.train.slice_input_producer:
slices = tf.train.slice_input_producer(values, shuffle=False)
With these modifications we know i_batch inside the log_prob function, and we can obtain the cluster assignments of every point in the batch: cn[self.i_batch * n] (n is the batch size).
Modification 3: Changes in examples/local_tf_mixture_gaussian.py.
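As a standalone illustration of the ordering behaviour we rely on (a toy sketch with made-up data, not our actual fork):
import numpy as np
import tensorflow as tf

data = np.arange(20, dtype=np.float32).reshape(10, 2)
M = 5  # batch size
# shuffle=False means batch i contains rows [i*M, (i+1)*M), in order
slices = tf.train.slice_input_producer([data], shuffle=False)
batch = tf.train.batch(slices, batch_size=M)
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
print(sess.run(batch))  # first call: rows 0..4
coord.request_stop()
coord.join(threads)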
Although this code runs properly, the algorithm doesn't converge, in contrast to your marginalized version.
Do you think it may be because sampling the local variables in the BBVI algorithm introduces a lot of variance? We have tried increasing the number of samples (n_samples) and we haven't observed any significant change. Are we missing something else?
Thank you for your attention