kimiyoung / transformer-xl


Tensor2Tensor compatibility #7

Open JohannesTK opened 5 years ago

JohannesTK commented 5 years ago

Thank you for such easy-to-read code and a well-organized repo; it is clear that a lot of hard work has gone into it! I also found your work through Sebastian Ruder's NLP newsletter, where he put it like this: "Peer review is an imprecise process and gems may sometimes fall through the cracks." Your work was listed as one of those gems, and I totally agree!

Now to the specific issue: I tried loading the wt103 checkpoint in Tensor2Tensor and I'm getting the following error:

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key transformer/body/decoder/layer_0/ffn/conv1/bias not found in checkpoint
     [[node save/RestoreV2_1 (defined at /home/ubuntu/tensor2tensor/venv/lib/python3.5/site-packages/tensor2tensor/utils/decoding.py:586)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
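
For reference, a minimal inspection sketch (assuming TensorFlow 1.x and a hypothetical local checkpoint path ./wt103/model.ckpt) that lists the variable names actually stored in the checkpoint, so they can be compared against the key the Tensor2Tensor graph asks for:

import tensorflow as tf

# Hypothetical path to the downloaded wt103 Transformer-XL checkpoint.
CKPT_PATH = "./wt103/model.ckpt"

# tf.train.list_variables returns (name, shape) pairs for every variable
# stored in the checkpoint, without having to build a graph.
for name, shape in tf.train.list_variables(CKPT_PATH):
  print(name, shape)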

I suppose this comes from me using the wrong hparams?

from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_xl():
  """HParams for Transformer-XL."""
  hparams = transformer.transformer_base()
  hparams.batch_size = 2048
  hparams.hidden_size = 4096
  hparams.filter_size = 3072
  hparams.num_hidden_layers = 18
  hparams.num_heads = 16
  hparams.max_length = 1024
  hparams.eval_drop_long_sequences = True
  return hparams

Tensor2Tensor transformer hparams
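
As a quick sanity check (this only validates the hparams object itself, not checkpoint compatibility), the registered set can be built directly and inspected; with t2t-trainer or t2t-decoder it would then be selected via --hparams_set=transformer_xl:

# Build the hparams set defined above and print the values that determine
# the shapes of the Transformer weights.
hparams = transformer_xl()
print(hparams.hidden_size, hparams.filter_size,
      hparams.num_hidden_layers, hparams.num_heads)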

kimiyoung commented 5 years ago

Glad that you like our work.

Our codebase is not compatible with Tensor2Tensor at this point. There are two reasons: 1) The computational graph we built in Transformer-XL contains components that are not part of the standard Transformer in Tensor2Tensor, including the recurrence mechanism and the new relative positional encodings. 2) The scope names used in Tensor2Tensor are different from ours.

Therefore it is not possible to load a Transformer-XL checkpoint by simply modifying hyperparameters in Tensor2Tensor. PRs on compatibility with Tensor2Tensor are more than welcome.
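
For anyone who wants to attempt such a PR: even a pure variable-renaming conversion would only be a starting point. A rough sketch (TensorFlow 1.x; name_map is a hypothetical mapping from Tensor2Tensor variable names to Transformer-XL names that would have to be worked out per layer) of pulling tensors out of a Transformer-XL checkpoint:

import tensorflow as tf

XL_CKPT = "./wt103/model.ckpt"  # hypothetical path to the Transformer-XL checkpoint

# Hypothetical name mapping; constructing it is the model-specific part.
name_map = {
    # "transformer/body/decoder/layer_0/ffn/conv1/bias": "<matching Transformer-XL name>",
}

reader = tf.train.load_checkpoint(XL_CKPT)
converted = {}
for t2t_name, xl_name in name_map.items():
  # get_tensor returns the stored value as a NumPy array.
  converted[t2t_name] = reader.get_tensor(xl_name)

Even with a complete mapping, the remapped weights would only cover the parts the two models share; the recurrence (memory) mechanism and the relative positional encodings have no counterpart in the stock Tensor2Tensor Transformer, so the converted model would not reproduce Transformer-XL.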

JohannesTK commented 5 years ago

Thanks for clearing it up.

JohannesTK commented 5 years ago

Any plans to add compatibility in the future?