tensorflow / models

Models and examples built with TensorFlow

max_seq_length doesn't change the model dimension #10834

Closed: SohaKhazaeli closed this issue 1 year ago

SohaKhazaeli commented 1 year ago

Prerequisites


1. The entire URL of the file you are using

https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert

2. Describe the bug

When I follow this tutorial, I set max_seq_length = 128 in one cell. Later, when the encoder and classifier are built and I call plot_model on bert_classifier and bert_encoder, the plots show the original dimension (e.g. 512). The expected behavior is that the dimension reflects the assigned max_seq_length = 128, not that of the loaded pretrained model.

3. Steps to reproduce

Follow the tutorial at https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert

4. Expected behavior

The expected behavior is that the dimension reflects the assigned max_seq_length = 128, not that of the loaded pretrained model.

5. Additional context


6. System information

TensorFlow 2.10

saberkun commented 1 year ago

max_seq_length only controls the packer:

    packer = tfm.nlp.layers.BertPackInputs(
        seq_length=max_seq_length,
        special_tokens_dict=tokenizer.get_special_tokens_dict())

max_seq_length does not affect the TF-Hub setup. It only affects the data.
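Conceptually, packing to a fixed seq_length trims or pads each example's token ids and builds the matching mask and segment ids; the model weights are untouched. The plain-Python sketch below illustrates that idea only. It is not the BertPackInputs implementation: the special-token ids (101 for [CLS], 102 for [SEP], 0 for padding) are assumptions from the common uncased BERT vocabulary, the round-robin trimming is modeled on what I understand to be BertPackInputs' default truncator, and `pack_inputs` is a hypothetical helper whose output keys follow the names BertPackInputs returns.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed ids (uncased BERT vocab)


def round_robin_lengths(lengths, budget):
    """Give each segment one token at a time until the budget runs out."""
    kept = [0] * len(lengths)
    while budget > 0:
        progressed = False
        for i, n in enumerate(lengths):
            if budget > 0 and kept[i] < n:
                kept[i] += 1
                budget -= 1
                progressed = True
        if not progressed:  # every segment already fits completely
            break
    return kept


def pack_inputs(segments, seq_length):
    """Hypothetical helper: pack [CLS] a [SEP] b [SEP] ... into fixed-length lists."""
    budget = seq_length - (1 + len(segments))  # room left after special tokens
    kept = round_robin_lengths([len(s) for s in segments], budget)
    word_ids, type_ids = [CLS_ID], [0]
    for segment_index, (segment, n) in enumerate(zip(segments, kept)):
        word_ids += list(segment[:n]) + [SEP_ID]
        type_ids += [segment_index] * (n + 1)  # segment tokens plus its [SEP]
    mask = [1] * len(word_ids)
    pad = seq_length - len(word_ids)
    return {
        "input_word_ids": word_ids + [PAD_ID] * pad,
        "input_mask": mask + [0] * pad,
        "input_type_ids": type_ids + [0] * pad,
    }


print(pack_inputs([[7, 8, 9], [4, 5]], seq_length=6)["input_word_ids"])
# → [101, 7, 8, 102, 4, 102]
```

Changing seq_length here only changes the length of these lists, which matches the point above that max_seq_length affects the data, not the encoder.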

SohaKhazaeli commented 1 year ago

Thanks @saberkun. If max_sequence_length could be set from max_seq_length, the model would be much faster. The model is already trained on just 128 WordPiece tokens, but it doesn't have the flexibility to get the speed benefit. I have exactly this use case: with the original BERT research repo on TF1, I used to cut the inputs at exactly max_seq_length=128. Now that I have moved to TensorFlow BERT (tf-models-official), I cannot get a smaller model by exporting the 0-128 elements of the input tensors, since that happens inside the model.

saberkun commented 1 year ago

max_seq_length affects only the position embedding: https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py#L136 For example, with sequence length = 8, the encoder does the embedding lookup for only those 8 position ids. If you change the max_sequence_length of BertEncoder to 128 (512 x hidden size -> 128 x hidden size), only this variable changes. It does not affect model size much.
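A back-of-the-envelope check of that claim, assuming BERT-base (hidden size 768) and float32 weights: shrinking the position-embedding table from 512 to 128 rows removes 384 x 768 = 294,912 parameters, about 1.2 MB, which is well under 1% of the roughly 239 MB variables file reported later in this thread. The `rows_used` helper is hypothetical and only mirrors the idea that a short input reads the first rows of the table; it is not the tf-models code.

```python
HIDDEN_SIZE = 768     # assumed: BERT-base hidden size
BYTES_PER_PARAM = 4   # assumed: float32 weights


def position_table_params(max_sequence_length, hidden_size=HIDDEN_SIZE):
    """Number of weights in a [max_sequence_length, hidden_size] table."""
    return max_sequence_length * hidden_size


def rows_used(table, seq_length):
    """Hypothetical: a sequence of length seq_length reads only the first rows."""
    if seq_length > len(table):
        raise ValueError("sequence is longer than max_sequence_length")
    return table[:seq_length]


saved = position_table_params(512) - position_table_params(128)
print(saved, saved * BYTES_PER_PARAM / 1e6)  # → 294912 1.179648
```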

I don't follow "get smaller model by exporting the 0-128 elements of input tensors".

The model does not have a fixed input length; the sequence dimension is dynamic: https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py#L218 The expected inputs are [batch size, sequence length], and both dimensions are dynamic. If you provide sequences with length=8, you should expect the model to run faster. If you need inputs of a fixed length, pad them before passing them to BertEncoder.
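Because both dimensions are dynamic, one option is to pad each batch only to its longest sequence instead of to a fixed 128. Whether that suits a given serving setup is an assumption, not something the encoder requires. The plain-Python sketch below compares the two; `pad_batch` is a hypothetical helper and padding id 0 is assumed.

```python
PAD_ID = 0  # assumed padding id


def pad_batch(batch, fixed_length=None):
    """Right-pad token-id lists to fixed_length, or to the batch maximum."""
    target = fixed_length if fixed_length is not None else max(len(x) for x in batch)
    ids, mask = [], []
    for x in batch:
        kept = list(x[:target])  # truncate if longer than target
        pad = target - len(kept)
        ids.append(kept + [PAD_ID] * pad)
        mask.append([1] * len(kept) + [0] * pad)
    return ids, mask


batch = [[101, 7, 8, 102], [101, 5, 6, 9, 4, 3, 2, 102]]
dynamic_ids, _ = pad_batch(batch)                  # shape [2, 8]
fixed_ids, _ = pad_batch(batch, fixed_length=128)  # shape [2, 128]
print(len(dynamic_ids[0]), len(fixed_ids[0]))  # → 8 128
```

Since the encoder's cost grows with sequence length, the [2, 8] batch should be much cheaper to run than the [2, 128] one.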

SohaKhazaeli commented 1 year ago

Thanks @saberkun. 1- Do you know of any reason why the BERT classifier in tf-models-official is slower than a BERT classifier based on the BERT research repo with TF1? I'm seeing a response-time increase of more than 20 percent.

When I was using the BERT research repo, I exported the BERT classifier with MAX_SEQ_LENGTH=128:

    import tensorflow as tf  # TensorFlow 1.x; in TF2 these APIs live under tf.compat.v1

    MAX_SEQ_LENGTH = 128

    def serving_input_fn():
        input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids')
        input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask')
        segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids,
        })()
        return input_fn

After that, my exported model looked like this:

    saved_model.pb                        764.4 KB
    variables/
        variables.data-00000-of-00001   239.4 MB
        variables.index                   8.0 KB

After exporting the BERT classifier from tf-models-official, I get:

    saved_model.pb                          6.8 MB
    variables/
        variables.data-00000-of-00001   239.5 MB
        variables.index                  11.6 KB

I understand that the BERT in tf-models-official handles tokenization inside the model using FastWordpieceTokenizer, so that is probably why saved_model.pb is larger. But I don't know why it is significantly slower than the old repo.

2- To make up for the increase in response time, I tried TensorRT with FP16 and FP32, but got no speed improvement. I don't know whether there is a good resource or tutorial on optimizing tf-models-official BERT.

saberkun commented 1 year ago
  1. The models have been benchmarked on TPU with XLA. We see the model run faster because we avoided reshapes by using tf.einsum. However, we did not compare against other runtimes. Overall, to optimize, I recommend using XLA. In TF2, when you export a SavedModel with tf.function, you can set jit_compile=True: https://www.tensorflow.org/api_docs/python/tf/function.
  2. Because XLA's dot general should be optimized for both GPU and TPU, this will hopefully accelerate inference. FP16 should also run faster; however, mixed precision may not have been enabled correctly.
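A minimal sketch of that export path, assuming a classifier that takes integer token ids. The `IdClassifier` module, its toy weights, and the signature name are hypothetical stand-ins for the fine-tuned BERT classifier; `tf.function(jit_compile=True)`, `tf.saved_model.save`, and `tf.saved_model.load` are standard TF2 APIs. The input signature keeps both the batch and sequence dimensions dynamic.

```python
import tempfile

import tensorflow as tf


class IdClassifier(tf.Module):
    """Hypothetical stand-in for a classifier that consumes integer ids."""

    def __init__(self, vocab_size=1000, hidden_size=16, num_classes=2):
        super().__init__()
        self.embeddings = tf.Variable(tf.random.normal([vocab_size, hidden_size]))
        self.kernel = tf.Variable(tf.random.normal([hidden_size, num_classes]))

    # jit_compile=True asks TF to compile the whole function with XLA.
    @tf.function(
        jit_compile=True,
        input_signature=[
            tf.TensorSpec([None, None], tf.int32, name="input_word_ids"),
            tf.TensorSpec([None, None], tf.int32, name="input_mask"),
        ],
    )
    def __call__(self, input_word_ids, input_mask):
        x = tf.gather(self.embeddings, input_word_ids)     # [batch, seq, hidden]
        m = tf.cast(input_mask, x.dtype)[..., tf.newaxis]  # [batch, seq, 1]
        pooled = tf.reduce_sum(x * m, axis=1) / tf.maximum(
            tf.reduce_sum(m, axis=1), 1.0)                 # masked mean pooling
        return {"logits": tf.matmul(pooled, self.kernel)}


model = IdClassifier()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir,
                    signatures={"serving_default": model.__call__})

serve = tf.saved_model.load(export_dir).signatures["serving_default"]
out = serve(input_word_ids=tf.constant([[101, 7, 8, 102]]),
            input_mask=tf.constant([[1, 1, 1, 1]]))
```

Because the exported function sees only int32 tensors, XLA can compile all of it; string inputs would need to be handled outside the compiled function.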

Lastly, the TF2 model also takes integer ids as inputs. The tokenizer is not included unless you use the TF Hub module with built-in preprocessing, in which case the comparison is not fair.

SohaKhazaeli commented 1 year ago

Thanks @saberkun. 1- I tried adding jit_compile=True to the tf.function parameters while exporting. The export worked, but when I tried to use the model for inference, I got this error:

InvalidArgumentError: Detected unsupported operations when trying to compile graph __inference_restored_function_body_27254[_XlaMustCompile=true,config_proto=6001324581131673121,executor_type=11160318154034397263] on XLA_GPU_JIT: _Arg (No registered '_Arg' OpKernel for XLA_GPU_JIT devices compatible with node {{node answers_list}} (OpKernel was found, but attributes didn't match) Requested Attributes: T=DT_STRING, _output_shapes=[[32]], _user_specified_name="answers_list", index=0){{node answers_list}} The op is created at: ....

2- Could you elaborate on number 2?

3- The increase I see (20%-30%) is end-to-end, not just the model's inference time; in both cases I include tokenization in the response time. Switching to FastWordpieceTokenizer should also have improved tokenization time, but in practice the end-to-end time got slower on a Tesla T4.
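To see where the end-to-end time goes, it can help to time tokenization and model inference separately. A minimal plain-Python sketch follows; `tokenize` and `predict` are hypothetical callables standing in for your tokenizer and exported classifier. One call is run first as warm-up, since tf.function tracing and XLA compilation happen on the first call.

```python
import statistics
import time


def time_stages(tokenize, predict, texts, repeats=20):
    """Median wall-clock seconds for tokenization, inference, and both."""
    predict(tokenize(texts))  # warm-up: excludes one-time tracing/compilation
    tok_times, model_times, total_times = [], [], []
    for _ in range(repeats):
        start = time.perf_counter()
        features = tokenize(texts)
        mid = time.perf_counter()
        predict(features)
        end = time.perf_counter()
        tok_times.append(mid - start)
        model_times.append(end - mid)
        total_times.append(end - start)
    return {
        "tokenize": statistics.median(tok_times),
        "predict": statistics.median(model_times),
        "end_to_end": statistics.median(total_times),
    }
```

To make sure the measured time covers the whole computation, having `predict` fetch its result to the host (for example with `.numpy()`) is a reasonable precaution.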

SohaKhazaeli commented 1 year ago

Thanks @saberkun. I separated out the tokenizer. Instead of exporting the input_processor and classifier together into the SavedModel, this time I exported only the classifier. That made prediction (including tokenization) around 42% faster. I also tried your recommendation of jit_compile=True, which made it around 30% faster on top of that (about 59% faster altogether). I don't know why exporting them together made it so slow. It's probably worth changing https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert, since it exports them together. I'll also check TF-TRT to see whether I get any further speedup.

laxmareddyp commented 1 year ago

Hi @SohaKhazaeli,

If the solution provided works for you, we will close this issue.

Thanks.

SohaKhazaeli commented 1 year ago

I tried TF-TRT on the exported classifier and still saw no gain. I am happy with the XLA speedup for now, though.

saberkun commented 1 year ago

We don’t have experience with TRT.

XLA does not support string ops. You can apply tf.function with jit_compile=True to just the transformer part. You have probably already done that.
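One way to structure this, as a sketch: keep the string handling in an ordinary tf.function and have it call a separate jit_compile=True function that only sees numeric tensors, so XLA never receives a string input like `answers_list` in the error above. The whitespace split and hash-bucket ids are hypothetical stand-ins for a real WordPiece tokenizer, and the toy weights in `SplitExportModel` stand in for the transformer.

```python
import tensorflow as tf


class SplitExportModel(tf.Module):
    """Strings in, logits out; only the numeric core is XLA-compiled."""

    def __init__(self, num_buckets=1000, hidden_size=16, num_classes=2):
        super().__init__()
        self.num_buckets = num_buckets
        self.embeddings = tf.Variable(tf.random.normal([num_buckets, hidden_size]))
        self.kernel = tf.Variable(tf.random.normal([hidden_size, num_classes]))

    @tf.function(jit_compile=True)
    def core(self, ids, mask):
        # Numeric tensors only: this is the part XLA can compile.
        x = tf.gather(self.embeddings, ids)            # [batch, seq, hidden]
        m = tf.cast(mask, x.dtype)[..., tf.newaxis]    # [batch, seq, 1]
        pooled = tf.reduce_sum(x * m, axis=1) / tf.maximum(
            tf.reduce_sum(m, axis=1), 1.0)             # masked mean pooling
        return tf.matmul(pooled, self.kernel)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def serve(self, texts):
        # String ops run here, outside XLA.
        tokens = tf.strings.split(texts)               # RaggedTensor of words
        ids = tf.ragged.map_flat_values(
            tf.strings.to_hash_bucket_fast, tokens, self.num_buckets)
        mask = tf.sequence_mask(ids.row_lengths(), dtype=tf.int32)
        return {"logits": self.core(ids.to_tensor(), mask)}


model = SplitExportModel()
logits = model.serve(tf.constant(["hello world", "a longer example input"]))["logits"]
```

The same split applies when exporting: the SavedModel signature can take strings while only the inner numeric function carries the XLA compilation.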

The TF2 BERT models on TF Hub provide multiple interfaces.

google-ml-butler[bot] commented 1 year ago

Are you satisfied with the resolution of your issue?