Status: Open. Rustleman opened this issue 4 years ago.
There seems to be an issue in the `inference_rnn` function, where the `inference_cell` and `generator_cell` are connected together:
https://github.com/ogroth/tf-gqn/blob/bc84f24b5c72ef4389d238d8e40b4d1678ff24fe/gqn/gqn_draw.py#L417-L421
It looks like the gradient flows through `z_q` from the generator back into the inference cell. Adding the line `z_q = tf.stop_gradient(z_q)` seems to improve the results when only the `generator_rnn` is used during testing.
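To make the proposed change concrete, here is a minimal sketch of what `tf.stop_gradient` does to the gradient path. It is not the tf-gqn code: two plain matrix multiplies (hypothetical `w_inference` and `w_generator`) stand in for the LSTM cells, the posterior sample `z_q` is drawn with the reparameterization trick, and the loss is a dummy squared output. It assumes TensorFlow 2 eager mode with `tf.GradientTape`; `tf.stop_gradient` cuts the gradient the same way in graph mode.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy weights standing in for the inference and generator cells.
w_inference = tf.Variable(tf.random.normal([5, 4]), name="w_inference")
w_generator = tf.Variable(tf.random.normal([2, 3]), name="w_generator")


def gradients(stop_grad_z_q):
    """Return [d loss / d w_inference, d loss / d w_generator]."""
    x = tf.ones([2, 5])  # dummy input batch
    with tf.GradientTape() as tape:
        stats = tf.matmul(x, w_inference)            # toy inference "cell" output
        mu, log_sigma = tf.split(stats, 2, axis=-1)  # posterior parameters
        eps = tf.random.normal(tf.shape(mu))
        z_q = mu + tf.exp(log_sigma) * eps           # reparameterized posterior sample
        if stop_grad_z_q:
            z_q = tf.stop_gradient(z_q)              # the change proposed in this issue
        out = tf.matmul(z_q, w_generator)            # toy generator "cell" consumes z_q
        loss = tf.reduce_mean(tf.square(out))
    # Sources the loss cannot reach return None (default unconnected_gradients).
    return tape.gradient(loss, [w_inference, w_generator])


grad_inf, grad_gen = gradients(stop_grad_z_q=False)
print(grad_inf is None, grad_gen is None)  # both gradients exist

grad_inf, grad_gen = gradients(stop_grad_z_q=True)
print(grad_inf is None, grad_gen is None)  # inference gradient is cut
```

In this toy the loss reaches the inference weights only through `z_q`, so after the stop they receive no gradient at all. In the full model the inference cell would presumably still be trained by any loss term that uses the posterior parameters directly, such as a KL term between posterior and prior.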