tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.

Generating text from arbitrary input string #303

Open ghego opened 6 years ago

ghego commented 6 years ago

I have a language model trained on a very large corpus. When I enter an input string in t2t-decoder interactive mode, it is completely ignored: the model does not use it to generate text. Any suggestions?

Here are the detailed instructions on how to reproduce this bug:

1) Train a language model. I trained mine on Wikipedia data using:

PROBLEM=languagemodel_wiki_full32k
MODEL=attention_lm
HPARAMS=attention_lm_base

DATA_DIR=/mnt/data/t2t_data/
TMP_DIR=/mnt/data/t2t_datagen/
TRAIN_DIR=/mnt/data/t2t_train/$PROBLEM/$MODEL-$HPARAMS

WORKER_GPU=16

t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --hparams='batch_size=4096' \
  --output_dir=$TRAIN_DIR \
  --local_eval_frequency=0 \
  --worker_gpu=$WORKER_GPU

2) Launch t2t-decoder in interactive mode like this:

PROBLEM=languagemodel_wiki_full32k
MODEL=attention_lm
HPARAMS=attention_lm_base

DATA_DIR=/mnt/data/t2t_data/
TMP_DIR=/mnt/data/t2t_datagen/
TRAIN_DIR=/mnt/data/t2t_train/$PROBLEM/$MODEL-$HPARAMS

WORKER_GPU=0

BEAM_SIZE=1
ALPHA=0.6

t2t-decoder \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --hparams='sampling_method=random' \
  --output_dir=$TRAIN_DIR \
  --decode_beam_size=$BEAM_SIZE \
  --decode_alpha=$ALPHA \
  --worker_gpu=$WORKER_GPU \
  --local_eval_frequency=0 \
  --decode_interactive

It should display this:

INTERACTIVE MODE  num_samples=1  decode_length=100
  it=<input_type>     ('text' or 'image' or 'label', default: text)
  pr=<problem_num>    (set the problem number, default: 0)
  in=<input_problem>  (set the input problem number)
  ou=<output_problem> (set the output problem number)
  ns=<num_samples>    (changes number of samples, default: 1)
  dl=<decode_length>  (changes decode length, default: 100)
  <source_string>                (decode)
  q                   (quit)

At this point, try changing the decode_length (dl=) or giving an arbitrary source_string: the output is always exactly the same.
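
For example (illustrative inputs only, the exact strings don't matter), typing any of the following at the prompt:

dl=50
The quick brown fox jumps over the lazy dog
dl=200
A completely different sentence about physics

produces exactly the same generated text every time.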

lukaszkaiser commented 6 years ago

You need to make sure that self.has_input is False here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L196

Currently that's taken from the Problem class and its input modalities, so the setting can be wrong. Could you try to manually set it to False here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L145

Does it work then?

ghego commented 6 years ago

I've just changed has_input to return False like this:

  @property
  def has_input(self):
    return False #self._problem_hparams.input_modality

and it crashes with this error:

 Traceback (most recent call last):
  File "/bin/t2t-decoder", line 6, in <module>
    exec(compile(open(__file__).read(), __file__, 'exec'))
  File "/tensor2tensor/tensor2tensor/bin/t2t-decoder", line 93, in <module>
    tf.app.run()
  File "/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/tensor2tensor/tensor2tensor/bin/t2t-decoder", line 83, in main
    decoding.decode_interactively(estimator, decode_hp)
  File "/tensor2tensor/tensor2tensor/utils/decoding.py", line 309, in decode_interactively
    for result in result_iter:
  File "/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 340, in predict
    model_fn_lib.ModeKeys.PREDICT)
  File "/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 615, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 344, in wrapping_model_fn
    return model_fn(model, features, mode, hparams, **kwargs)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 167, in model_fn
    max_idx=len(hparams.problems) - 1)
  File "/tensor2tensor/tensor2tensor/utils/input_fn_builder.py", line 187, in cond_on_index
    flat_out = wrapped_fn()
  File "/tensor2tensor/tensor2tensor/utils/input_fn_builder.py", line 182, in wrapped_fn
    out = fn(cur_idx)
  File "/tensor2tensor/tensor2tensor/utils/model_builder.py", line 123, in nth_model
    decode_length=decode_hp.extra_length)
  File "/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 207, in infer
    last_position_only)
  File "/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 352, in _greedy_infer
    logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
UnboundLocalError: local variable 'batch_size' referenced before assignment

ghego commented 6 years ago

I've also tried forcing the "if not self.has_input" check in t2t_model.py to be True, and I get the exact same error.

ghego commented 6 years ago

OK, so the problem is that the if statement on line 335 in t2t_model.py does not define batch_size, which is later required in the same function at line 352.
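
In pseudo-code, the shape of the bug looks like this (my paraphrase based on the traceback, not the exact source):

  # inside _greedy_infer, paraphrased
  if self.has_input:                     # the if on line 335
      inputs = features["inputs"]
      batch_size = tf.shape(inputs)[0]   # batch_size only defined on this branch
      ...
  # line 352: UnboundLocalError when has_input is False, because
  # batch_size was never assigned above
  logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))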

ghego commented 6 years ago

Also, the model does seem to have input. Adding

  p_hparams = hparams.problems[problem_id]
  tf.logging.info("DEBUG input_modality: %r" % p_hparams.input_modality)
  has_input = "inputs" in p_hparams.input_modality
  vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
  tf.logging.info("DEBUG has_input: %r" % has_input)

in decoding._interactive_input_fn yields:

INFO:tensorflow:DEBUG input_modality: {'inputs': ('symbol', 32599)}
INFO:tensorflow:DEBUG has_input: True

ghego commented 6 years ago

OK, more on this. I can run the code with these two changes:

1) Set has_inputs to return False in wiki.py.

2) Add batch_size=1 in t2t_model.py, just before line 336 (rough sketch of both edits below).
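
Concretely, the edits look roughly like this (paraphrased; line numbers are from my checkout and may have drifted):

  # 1) in tensor2tensor/data_generators/wiki.py, on the problem class
  @property
  def has_inputs(self):
    return False  # treat it as a pure language-modeling problem

  # 2) in tensor2tensor/utils/t2t_model.py, just before the if around line 335/336,
  #    so the no-input path has a batch size to work with
  batch_size = 1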

With these two changes the decoder correctly picks up the input string and reproduces it; however, it truncates the output at exactly the same length as the input. I suspect this is because the function log_decode_results in decoding.py truncates the output string at the first <EOS>, which is copied over from the input string.
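
i.e. I suspect the effect is something like this (my guess at the behaviour, not the actual code):

  EOS_ID = 1  # t2t's reserved end-of-sequence id

  def save_until_eos(ids):
      # keep everything up to (but not including) the first EOS;
      # since the echoed input already contains an EOS, whatever is
      # generated after it gets dropped here
      ids = list(ids)
      return ids[:ids.index(EOS_ID)] if EOS_ID in ids else ids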

rsepassi commented 6 years ago

Have you tried this again recently? Is it still an issue?

ghego commented 6 years ago

I have not tried this recently.