tensorflow / models

Models and examples built with TensorFlow

Expose `mlm` logits in Albert TF2 Hub Model #9613

Closed eddie-scio closed 3 years ago

eddie-scio commented 3 years ago

Prerequisites

Please answer the following questions for yourself before submitting an issue.

1. The entire URL of the file you are using

https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py

2. Describe the feature you request

I have been using https://tfhub.dev/google/albert_base/3 with the mlm signature to get mlm_logits for domain adaptation on the base ALBERT LM (training on an ML Engine TPU). I'm trying to upgrade to TF 2.0, so I found the TF2 Hub model, but its outputs only include ['sequence_output', 'encoder_outputs', 'default', 'pooled_output'], with no MLM signature. That's understandable given the model is framed as an encoder, but it means I can't upgrade my workflow to TF 2.0. (A rough sketch of the current TF1 usage is under "Additional context" below.)

What's the supported TF 2.0 route for my workflow?

3. Additional context
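
For reference, the TF1 workflow I'm migrating from looks roughly like the sketch below. This is a reconstruction, not copied from the module's docs; the exact input key names of the albert_base/3 mlm signature are my best guess.

import tensorflow_hub as hub  # TF1-style hub.Module API

albert_module = hub.Module("https://tfhub.dev/google/albert_base/3",
                           trainable=True)
mlm_outputs = albert_module(
    inputs=dict(                        # key names assumed, verify against
        input_ids=input_ids,            # the module; [batch, seq_length]
        input_mask=input_mask,          # [batch, seq_length], int32
        segment_ids=segment_ids,        # [batch, seq_length], int32
        mlm_positions=mlm_positions),   # [batch, num_predictions], int32
    signature="mlm",
    as_dict=True)
mlm_logits = mlm_outputs["mlm_logits"]  # [batch, num_predictions, vocab_size]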

4. Are you willing to contribute it? (Yes or No)

Yes

aichendouble commented 3 years ago

Hi,

The TF2 Hub model actually does have a .mlm signature (sorry that we have not updated the docs yet). Please see the usage below:

import tensorflow_hub as hub

albert = hub.load("https://tfhub.dev/tensorflow/albert_en_base/2")
bert_mlm_layer = hub.KerasLayer(albert.mlm, trainable=True)

# input_word_ids, input_mask, input_type_ids, and mlm_positions are int32
# tensors produced by your ALBERT preprocessing/masking pipeline.
mlm_inputs = dict(
    input_word_ids=input_word_ids,     # Shape [batch, seq_length], dtype=int32
    input_mask=input_mask,             # Shape [batch, seq_length], dtype=int32
    input_type_ids=input_type_ids,     # Shape [batch, seq_length], dtype=int32
    masked_lm_positions=mlm_positions, # Shape [batch, num_predictions],
                                       # dtype=int32
)
mlm_outputs = bert_mlm_layer(mlm_inputs)
assert mlm_outputs.keys() == {
    "pooled_output",    # Shape [batch_size, width], dtype=float32
    "sequence_output",  # Shape [batch_size, seq_length, width], dtype=float32
    "mlm_logits",       # Shape [batch_size, num_predictions, vocab_size],
                        # dtype=float32
    "encoder_outputs",  # A list of tensors, each corresponding to the output
                        # of a transformer layer.
}
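
From there, mlm_logits can be fed into the standard masked-LM loss for domain adaptation. A minimal sketch (my addition, not part of the published API), assuming masked_lm_ids holds the true token ids at the masked positions and masked_lm_weights zeros out padded prediction slots:

import tensorflow as tf

logits = mlm_outputs["mlm_logits"]  # [batch, num_predictions, vocab_size]
# masked_lm_ids: int32 [batch, num_predictions]; masked_lm_weights: float32
# [batch, num_predictions], from the same masking pipeline as mlm_positions.
per_position_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=masked_lm_ids, logits=logits)          # [batch, num_predictions]
mlm_loss = (tf.reduce_sum(per_position_loss * masked_lm_weights) /
            (tf.reduce_sum(masked_lm_weights) + 1e-5))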
eddie-scio commented 3 years ago

Wow, thanks for the quick update, that's awesome! I'll give this a whirl, make sure it works, and report back.

google-ml-butler[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler[bot] commented 3 years ago

Closing as stale. Please reopen if you'd like to work on this further.