ghego opened this issue 7 years ago
You need to make sure that `self.has_input` is False here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L196

Currently that's taken from the `Problem` class and its input modalities, so the setting can be wrong. Could you try manually setting it to False here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L145

Does it work then?
I've just changed `has_input` to return False like this:

```python
@property
def has_input(self):
    return False  # self._problem_hparams.input_modality
```

and it crashes with this error:
```
Traceback (most recent call last):
  File "/bin/t2t-decoder", line 6, in <module>
    exec(compile(open(__file__).read(), __file__, 'exec'))
  File "/tensor2tensor/tensor2tensor/bin/t2t-decoder", line 93, in <module>
    tf.app.run()
  File "/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/tensor2tensor/tensor2tensor/bin/t2t-decoder", line 83, in main
    decoding.decode_interactively(estimator, decode_hp)
  File "/tensor2tensor/tensor2tensor/utils/decoding.py", line 309, in decode_interactively
    for result in result_iter:
  File "/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 340, in predict
    model_fn_lib.ModeKeys.PREDICT)
  File "/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 615, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 344, in wrapping_model_fn
    return model_fn(model, features, mode, hparams, **kwargs)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 167, in model_fn
    max_idx=len(hparams.problems) - 1)
  File "/tensor2tensor/tensor2tensor/utils/input_fn_builder.py", line 187, in cond_on_index
    flat_out = wrapped_fn()
  File "/tensor2tensor/tensor2tensor/utils/input_fn_builder.py", line 182, in wrapped_fn
    out = fn(cur_idx)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 123, in nth_model
    decode_length=decode_hp.extra_length)
  File "/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 207, in infer
    last_position_only)
  File "/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 352, in _greedy_infer
    logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
UnboundLocalError: local variable 'batch_size' referenced before assignment
```
I've also tried setting `if not self.has_input` to `True` in `T2TModel`, and I get the exact same error.
OK, so the problem is due to the fact that the `if` statement on line 335 in `t2t_model.py` does not define `batch_size`, which is later required in the same function at line 352.
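This is the standard Python failure mode where a local variable is only bound on one branch. A minimal standalone sketch of the pattern (hypothetical `greedy_infer_sketch`, not the actual t2t code):

```python
def greedy_infer_sketch(features):
    # Mirrors the bug pattern in _greedy_infer: batch_size is only bound
    # inside the "has inputs" branch, so the pure language-model path
    # (no "inputs" key in features) never defines it.
    if "inputs" in features:
        batch_size = features["inputs"].shape[0]
    # Reached on the no-input path with batch_size unbound:
    return (batch_size, 0, 1, 1)

try:
    greedy_infer_sketch({"targets": None})
except UnboundLocalError as e:
    print("reproduced:", e)
```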
Also, the model does seem to have input. Adding

```python
p_hparams = hparams.problems[problem_id]
tf.logging.info("DEBUG input_modality: %r" % p_hparams.input_modality)
has_input = "inputs" in p_hparams.input_modality
vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
tf.logging.info("DEBUG has_input: %r" % has_input)
```

in `decoding._interactive_input_fn` yields:

```
INFO:tensorflow:DEBUG input_modality: {'inputs': ('symbol', 32599)}
INFO:tensorflow:DEBUG has_input: True
```
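For reference, the debug check above is a plain dict membership test; with the modality exactly as logged it reduces to this standalone sketch (no t2t imports needed):

```python
# The input_modality dict exactly as logged for this problem:
input_modality = {"inputs": ("symbol", 32599)}

# The same check the snippet runs in decoding._interactive_input_fn:
has_input = "inputs" in input_modality
vocab_key = "inputs" if has_input else "targets"
print(has_input, vocab_key)  # -> True inputs
```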
OK, more on this. I can run the code with these two changes:

1) set `has_inputs` to return False in `wiki.py`
2) add `batch_size = 1` in `t2t_model.py`, just before line 336.
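Change 2 amounts to binding `batch_size` to a default before the branch runs. A minimal sketch of the pattern (hypothetical function, not the real t2t code; a default of 1 is assumed for interactive decoding):

```python
def greedy_infer_patched(features, default_batch_size=1):
    # Workaround pattern: bind batch_size before the conditional, so the
    # no-input (language model) path still has a value at the point where
    # the logits shape is built.
    batch_size = default_batch_size
    if "inputs" in features:
        batch_size = features["inputs"].shape[0]
    return (batch_size, 0, 1, 1)  # stand-in for the logits shape

print(greedy_infer_patched({"targets": None}))  # -> (1, 0, 1, 1)
```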
With these two changes the decoder correctly takes the input string and reproduces it; however, it truncates the output at exactly the length of the input. I suspect this is because `log_decode_results` in `decoding.py` truncates the output string at the first `<EOS>`, which is copied from the input string.
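If that suspicion is right, the effect can be shown with plain strings (hypothetical helper; the real `log_decode_results` operates on decoded token sequences):

```python
EOS = "<EOS>"

def truncate_at_first_eos(decoded):
    # Mimics the suspected behavior: drop everything from the first EOS on.
    idx = decoded.find(EOS)
    return decoded if idx == -1 else decoded[:idx]

# If the model echoes the input (with its EOS) before continuing, the
# copied EOS cuts off the freshly generated continuation:
decoded = "the input string <EOS> generated continuation"
print(repr(truncate_at_first_eos(decoded)))  # -> 'the input string '
```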
Have you tried this again recently? Is it still an issue?
I have not tried this recently.
I have a language model trained on a very large corpus. I can type any input string in the `t2t_decoder` interactive mode and it is completely ignored; the model doesn't use it to generate text. Any suggestions?
Here are the detailed instructions on how to reproduce this bug:

1) Train a language model. I trained mine on Wikipedia data using:

2) Launch the interactive `t2t_decoder` like this:

It should display this:

At this point try changing `decode_length` or giving an arbitrary `source_string`; the output is always the same.