thu-ml / zhusuan

A probabilistic programming library for Bayesian deep learning and generative models, based on TensorFlow
http://zhusuan.readthedocs.io
MIT License

Can't compute prior (local_log_prob) of a StochasticTensor inside tf.scan (in LSTM cell) #104

Closed · jilljenn closed this issue 6 years ago

jilljenn commented 6 years ago

Hi,

I tried to implement the bayesian_rnn example from the docs. However, when computing log_joint, I can't compute the log-prior log_pz because w is declared within an LSTM cell, so I get the following error:

Traceback (most recent call last):
  File "blstm.py", line 111, in <module>
    joint_ll = log_joint({'x': x, 'y_i': y_i, 'y_v': y_v})
  File "blstm.py", line 106, in log_joint
    log_pz, log_px_z = model.local_log_prob(['w', 'y_v'])  # Error
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/zhusuan/model/base.py", line 346, in local_log_prob
    ret.append(s_tensor.log_prob(s_tensor.tensor))
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/zhusuan/model/base.py", line 140, in log_prob
    return self._distribution.log_prob(given)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/zhusuan/utils.py", line 215, in _func
    return f(*args, **kwargs)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/zhusuan/distributions/base.py", line 303, in log_prob
    log_p = self._log_prob(given)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/zhusuan/distributions/univariate.py", line 180, in _log_prob
    return c - logstd - 0.5 * precision * tf.square(given - mean)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 979, in binary_op_wrapper
    return func(x, y, name=name)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 8009, in sub
    "Sub", x=x, y=y, name=name)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1760, in __init__
    self._control_flow_post_processing()
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1769, in _control_flow_post_processing
    control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_util.py", line 263, in CheckInputFromValidContext
    raise ValueError(error_msg + " See info log for more details.")
ValueError: Cannot use 'model/scan/while/MatMul/Normal.sample/Squeeze' as input to 'Normal.log_prob/sub_1' because 'model/scan/while/MatMul/Normal.sample/Squeeze' is in a while loop. See info log for more details.

Runnable code: https://github.com/jilljenn/vae/blob/master/blstm.py#L106

How can I make this work?

thjashin commented 6 years ago

Hi, which version of TF are you working with? I think I had that example running on an earlier version.

jilljenn commented 6 years ago

TensorFlow 1.8.0. I will try later, thanks!

jilljenn commented 6 years ago

1.8 => fail
1.6 => fail (same error)
1.5, 1.4 => I'm getting a different error, caused by my tf.einsum operation:

ValueError: Cannot take the length of Shape with unknown rank.

jilljenn commented 6 years ago

In TensorFlow 1.4, the error becomes:

Traceback (most recent call last):
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'Normal.log_prob/sub_1' has inputs from different frames. The input 'model/scan/while/MatMul/Normal.sample/Squeeze' is in frame 'model/scan/while/while_context'. The input 'model/zeros' is in frame ''.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "blstm.py", line 144, in <module>
    feed_dict={x: x_batch, y_i: y_i_batch, y_v: y_v_batch})
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/Users/jilljenn/code/vae/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'Normal.log_prob/sub_1' has inputs from different frames. The input 'model/scan/while/MatMul/Normal.sample/Squeeze' is in frame 'model/scan/while/while_context'. The input 'model/zeros' is in frame ''.

thjashin commented 6 years ago

Hi, I need to walk through your code to figure out what happened, but unfortunately I haven't got much time yet (before the NIPS deadline). Could you take a look at this script, which corresponds to the example in the white paper?

csy530216 commented 6 years ago

Short answer: if you change the line joint_ll = log_joint({'x': x, 'y_i': y_i, 'y_v': y_v}) to joint_ll = log_joint({'w': tf.zeros([2 * 128 + 1, 4 * 128]), 'y_v': y_v}), the error will not occur.

Long answer: in your code,

@zs.reuse('model')
def p_net(observed, seq_len):
    '''
    Encoder: p(x|z) = p(y_v|w)
    '''
    with zs.BayesianNet(observed=observed) as model:
        cell = BayesianLSTMCell(128, forget_bias=0.)
        # shape was [max_seq_len, nb_batches, nb_classes]
        h_list = bayesian_rnn(cell, x, y_i)
        item_features = tf.get_variable("item_features", shape=[nb_items, embedding_size, nb_classes],
                                        initializer=tf.truncated_normal_initializer(stddev=0.02))
        relevant_items = tf.nn.embedding_lookup(item_features, y_i, name="feat_items")
        logits = tf.tensordot(h_list, relevant_items, axes=[[2], [2]])  # That's not even the right shape, but anyway
        _ = zs.Categorical('y_v', logits)  # shape of its local_log_prob = [max_seq_len, nb_batches]
                                           # because we already observe the true variable (y_v is in observed)
    return model

def log_joint(observed):
    model = p_net(observed, seq_len)
    # print('all', model._stochastic_tensors)  # w and y_v
    log_pz, log_px_z = model.local_log_prob(['w', 'y_v'])  # Error
    # log_px_z = model.local_log_prob('y_v')
    return log_pz + log_px_z  # Error
    # return log_px_z

joint_ll = log_joint({'x': x, 'y_i': y_i, 'y_v': y_v})

So in the zs.BayesianNet context, the StochasticTensors defined are w (created when calling BayesianLSTMCell.__init__) and y_v. The observed parameter of zs.BayesianNet assigns observed values to some of these StochasticTensors. Therefore, the joint_ll line should be something like joint_ll = log_joint({'w': blabla, 'y_v': y_v}); you need not supply values for x and y_i, since there are no StochasticTensors named x or y_i in the BayesianNet. Moreover, since you want to compute the log_joint of the BayesianNet, the values of both w and y_v should be given (otherwise the value of w would be a sample drawn according to self._w = zs.Normal(...)), to make the following code reasonable:

def log_joint(observed):
    model = p_net(observed, seq_len)
    log_pz, log_px_z = model.local_log_prob(['w', 'y_v'])  # Error
    return log_pz + log_px_z  # Error

I think the reason the error happens is this: when calculating model.local_log_prob('w'), we need a value of w to evaluate p(w). If the value of w is not given in observed (as in your code), it is set to its sample (in your code, sampled from zs.Normal('w', w_mean, std=1., group_ndims=2)). However, in your code self._w appears in the __call__ method of the BayesianLSTMCell, so w is sampled inside the loop body of the while_loop created by tf.scan. It seems that this is what triggers the error when calculating model.local_log_prob('w'), though I am not entirely clear about the reason.
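To illustrate the TensorFlow rule that seems to be at play, here is a minimal pure-TF sketch (illustrative only, not from the original code; it assumes TF 1.x graph mode, and the names body and leaked are my own). Tensors created inside tf.scan's while loop live in a separate control-flow frame and cannot feed ops built outside that loop, which is exactly what Normal.log_prob does with the in-loop sample. On recent versions the check fires at graph construction ("is in a while loop"); on 1.4 it only fired at run time ("inputs from different frames"), matching both tracebacks above.

import tensorflow as tf

elems = tf.range(5, dtype=tf.float32)
leaked = []  # holds a reference to a tensor created inside the loop

def body(acc, x):
    s = tf.random_normal([])  # sample created in the while-loop frame
    leaked.append(s)
    return acc + s * x

out = tf.scan(body, elems, initializer=tf.constant(0.))

# Building an op that mixes the in-loop sample with an outside tensor
# raises the same error as Normal.log_prob does above:
bad = leaked[0] - tf.zeros([])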

I guess that in the objective function you want to optimize, the StochasticTensor should be integrated out. Then you could try variational inference with the evidence lower bound via zs.variational.elbo; please refer to JiaXin's code (you need to change zs.sgvb to zs.variational.elbo and add the dataset loader to his code), or ZhuSuan's tutorial on Bayesian NN.
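For concreteness, a minimal sketch of that route under my own assumptions: a mean-field normal q(w) sampled outside the scan loop; q_net, w_shape, and the optimizer settings are illustrative, and log_joint is the function defined earlier in this thread.

import tensorflow as tf
import zhusuan as zs

w_shape = [2 * 128 + 1, 4 * 128]

@zs.reuse('variational')
def q_net():
    with zs.BayesianNet() as variational:
        w_mean = tf.get_variable(
            'w_mean', shape=w_shape, initializer=tf.zeros_initializer())
        w_logstd = tf.get_variable(
            'w_logstd', shape=w_shape, initializer=tf.zeros_initializer())
        # Mean-field normal approximation q(w), sampled outside tf.scan
        zs.Normal('w', w_mean, logstd=w_logstd, group_ndims=2)
    return variational

variational = q_net()
qw_samples, log_qw = variational.query(
    'w', outputs=True, local_log_prob=True)
lower_bound = zs.variational.elbo(
    log_joint, observed={'y_v': y_v},
    latent={'w': [qw_samples, log_qw]})
cost = tf.reduce_mean(lower_bound.sgvb())
optimizer = tf.train.AdamOptimizer(1e-3).minimize(cost)

Because qw_samples is drawn in q_net, outside the scan loop, both the sample and its log-probability live in the outer frame and the control-flow error should not occur.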

jilljenn commented 6 years ago

Thanks a ton!

I haven't got that much time yet (before the nips ddl)

We're all in this!

jilljenn commented 6 years ago

@csy530216 Can I still ask you questions?

Your approximation of the posterior (the variational distribution) in both bayesian_rnn.py and the Bayesian NN tutorial does not seem to depend on the observed variables. Why is that? Is that the intended behavior?

csy530216 commented 6 years ago

Yes. In a VAE, the latent variable z is local (the number of latent variables is proportional to the size of the dataset). To accelerate inference we use amortized inference, i.e., we make the variational distribution of z depend on the value of the observed x. Then, by learning an inference network, experience can be shared between different data points. In principle, in a VAE we could also make the variational distribution independent of the observed variables, as in the Bayesian NN, which makes the variational family more flexible, but inference would be much less efficient (maybe there are other advantages and drawbacks). In a Bayesian NN the latent variable W is global (shared between different data points), so the variational distribution need not depend on the observed variables.
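A hypothetical side-by-side sketch of the two styles; shapes, layer sizes, and variable names here are my own illustrations, not taken from the library's examples.

import tensorflow as tf
import zhusuan as zs

x = tf.placeholder(tf.float32, shape=[None, 784])

# Amortized (VAE-style): the parameters of q(z|x) are computed by an
# inference network from the observed x, so experience is shared
# across data points through the network weights.
with zs.BayesianNet() as q_local:
    h = tf.layers.dense(x, 100, activation=tf.nn.relu)
    z_mean = tf.layers.dense(h, 10)
    z_logstd = tf.layers.dense(h, 10)
    zs.Normal('z', z_mean, logstd=z_logstd, group_ndims=1)

# Global (Bayesian-NN-style): q(W) has free parameters of its own and
# never looks at the data, since W is shared by all data points.
with zs.BayesianNet() as q_global:
    w_mean = tf.get_variable(
        'w_mean', shape=[784, 10], initializer=tf.zeros_initializer())
    w_logstd = tf.get_variable(
        'w_logstd', shape=[784, 10], initializer=tf.zeros_initializer())
    zs.Normal('W', w_mean, logstd=w_logstd, group_ndims=2)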

jilljenn commented 6 years ago

Thanks! I guess it means q(W) reflects the posterior over the whole dataset D.

For my purposes, I need to make it depend on the datapoints. I will do my best :) Thanks for such a great library, great docs, and great customer service! 😄

csy530216 commented 6 years ago

Yes you are right. Thanks for your kind words :)