tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

SYMBOL modality vocab size #68

Closed nuance closed 7 years ago

nuance commented 7 years ago

I'm trying to train a bytes-to-subwords model:

def problem(model_hparams):
    # This vocab file must be present within the data directory.
    vocab_filename = os.path.join(model_hparams.data_dir, 'vocab')

    source_encoder = text_encoder.ByteTextEncoder()
    target_encoder = text_encoder.SubwordTextEncoder(vocab_filename)

    p = problem_hparams.default_problem_hparams()
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_encoder.vocab_size)}
    p.target_modality = (registry.Modalities.SYMBOL, target_encoder.vocab_size)
    p.vocabulary = {
        "inputs": source_encoder,
        "targets": target_encoder,
    }

    return p

This fails catastrophically during model construction. It appears to work if the input and target modalities have the same vocab size (e.g., switching both to share the same SubwordTextEncoder) but fails if they differ in size. Other modalities do not seem to have this problem (e.g., changing both of the above to CLASS_LABEL appears to work).

vthorsteinsson commented 7 years ago

This looks a bit mixed up; for instance the p.vocabulary fields should probably be source_encoder for "inputs" and target_encoder for "targets"?

nuance commented 7 years ago

@vthorsteinsson you're right, this was a poorly reduced example (oops...). Edited the code to fix that issue.

rsepassi commented 7 years ago

What do you mean by "fails catastrophically"? Is there an error message you could share?

nuance commented 7 years ago

Here you go:

Traceback (most recent call last):
  File ".virt/bin/t2t-trainer", line 83, in <module>
    tf.app.run()
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File ".virt/bin/t2t-trainer", line 79, in main
    schedule=FLAGS.schedule)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/trainer_utils.py", line 240, in run
    run_locally(exp_fn(output_dir))
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/trainer_utils.py", line 532, in run_locally
    exp.train()
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 275, in train
    hooks=self._train_monitors + extra_hooks)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 665, in _call_train
    monitors=hooks)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func
    return func(*args, **kwargs)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 455, in fit
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 955, in _train_model
    model_fn_ops = self._get_train_ops(features, labels)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1162, in _get_train_ops
    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1133, in _call_model_fn
    model_fn_results = self._model_fn(features, labels, **kwargs)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/trainer_utils.py", line 424, in model_fn
    len(hparams.problems) - 1)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/trainer_utils.py", line 751, in _cond_on_index
    return fn(cur_idx)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/trainer_utils.py", line 406, in nth_model
    features, skip=(skipping_is_on and skip_this_one))
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/t2t_model.py", line 396, in model_fn
    sharded_features["targets"], dp)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/modality.py", line 115, in targets_bottom_sharded
    return data_parallelism(self.targets_bottom, xs)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/expert_utils.py", line 294, in __call__
    outputs.append(fns[i](*my_args[i], **my_kwargs[i]))
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/models/modalities.py", line 94, in targets_bottom
    return self.bottom_simple(x, "shared", reuse=True)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/models/modalities.py", line 79, in bottom_simple
    var = self._get_weights()
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/models/modalities.py", line 67, in _get_weights
    0.0, self._body_input_depth**-0.5)))
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 1065, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 962, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 360, in get_variable
    validate_shape=validate_shape, use_resource=use_resource)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensor2tensor/utils/expert_utils.py", line 260, in DaisyChainGetter
    var = getter(name, *args, **kwargs)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 352, in _true_getter
    use_resource=use_resource)
  File "/Users/matt/.virt/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 682, in _get_single_variable
    "VarScope?" % name)
ValueError: Variable symbol_modality_31228_512/shared/weights_0 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

lukaszkaiser commented 7 years ago

Transformer by default tries to share embedding and softmax weights: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py#L289

This is of course not possible when embedding and softmax have different sizes. Could you try to set it to False instead of True in the above line, and let us know if it works? Thanks!
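
A minimal sketch of why the lookup in the traceback fails, using a toy variable store rather than TensorFlow itself. It assumes, based only on the variable name in the error (symbol_modality_31228_512), that the modality's scope name embeds the vocab size and the hidden size. The `VariableStore` class and the `modality_scope` and `build_shared` helpers are hypothetical stand-ins, not tensor2tensor or TensorFlow APIs, and the input vocab size of 256 is illustrative.

```python
class VariableStore:
    """Toy stand-in for a TF1-style variable store (hypothetical)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, shape, reuse=False):
        if reuse:
            # With reuse=True the variable must already exist under this name.
            if name not in self._vars:
                raise ValueError("Variable %s does not exist" % name)
            return self._vars[name]
        self._vars[name] = shape
        return shape


def modality_scope(vocab_size, depth):
    # Assumption: the scope name is built from vocab size and depth,
    # as the name in the traceback suggests.
    return "symbol_modality_%d_%d" % (vocab_size, depth)


def build_shared(store, input_vocab, target_vocab, depth):
    """Create the input weights, then try to reuse them for the targets."""
    in_name = modality_scope(input_vocab, depth) + "/shared/weights_0"
    store.get_variable(in_name, (input_vocab, depth))
    tgt_name = modality_scope(target_vocab, depth) + "/shared/weights_0"
    return store.get_variable(tgt_name, (target_vocab, depth), reuse=True)


# Same vocab size on both sides: the target lookup finds the input weights.
build_shared(VariableStore(), 31228, 31228, 512)

# Different vocab sizes: the target scope name was never created.
try:
    build_shared(VariableStore(), 256, 31228, 512)
except ValueError as err:
    print(err)
```

Under these assumptions, the second call raises because the shared weights were created under a scope named for the input vocab size, while the target side looks them up under a scope named for its own size.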

cshanbo commented 7 years ago

Ah, I had the same issue. And yes, if we use separate vocabs with different sizes, the exception is raised. My solution is either to set the embedding/softmax weight-sharing hyperparameter to 0 when using separate vocabs, or to use a merged vocab instead.

The code will be something like:

def translate_zhen(model_hparams):
  """Chinese to English translation benchmark."""
  p = default_problem_hparams()
  # This vocab file must be present within the data directory.
  source_vocab_filename = os.path.join(model_hparams.data_dir, "vocab.zh")
  target_vocab_filename = os.path.join(model_hparams.data_dir, "vocab.en")
  source_token = text_encoder.TokenTextEncoder(vocab_filename=source_vocab_filename)
  target_token = text_encoder.TokenTextEncoder(vocab_filename=target_vocab_filename)

  p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_token.vocab_size)}
  p.target_modality = (registry.Modalities.SYMBOL,
                       target_token.vocab_size)

  if model_hparams.shared_embedding_and_softmax_weights == 1:
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
                                   source_token.vocab_size + target_token.vocab_size)}
    p.target_modality = (registry.Modalities.SYMBOL,
                         source_token.vocab_size + target_token.vocab_size)
  p.vocabulary = {
      "inputs": source_token,
      "targets": target_token,
  }
  p.loss_multiplier = 1.4
  p.input_space_id = 16
  p.target_space_id = 4
  return p

lukaszkaiser commented 7 years ago

Or we could just set model_hparams.shared_embedding_and_softmax_weights = 0 every time the modalities aren't the same. How about making a PR with the zh-en data-set?
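
A rough sketch of that idea, assuming the modalities are given as (type, vocab_size) tuples like the ones in the problem definitions above. The helper name `disable_sharing_if_mismatched` is hypothetical, and `SimpleNamespace` stands in for the real hparams object; only the `shared_embedding_and_softmax_weights` field comes from this thread.

```python
from types import SimpleNamespace


def disable_sharing_if_mismatched(model_hparams, input_modality, target_modality):
    """Turn off weight sharing when input and target modality specs differ.

    Hypothetical helper: input_modality and target_modality are
    (type, vocab_size) tuples, as in the problem definitions above.
    """
    if input_modality != target_modality:
        model_hparams.shared_embedding_and_softmax_weights = 0
    return model_hparams


# Stand-in for the real hparams object, with sharing enabled by default.
hparams = SimpleNamespace(shared_embedding_and_softmax_weights=1)

# Vocab sizes here are illustrative; they differ, so sharing is disabled.
disable_sharing_if_mismatched(hparams, ("symbol", 256), ("symbol", 31228))
print(hparams.shared_embedding_and_softmax_weights)
```

When the two specs are equal, the helper leaves the setting untouched, so problems that already share a vocab keep the default behavior.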