Hello everyone,

We are trying to integrate a pre-trained BERT embedding into our TFRS model. Our model is based on the same definition as https://www.tensorflow.org/recommenders/examples/basic_retrieval, where BertEmbeddingLayer is a custom Keras layer that wraps the pre-trained encoder. We decided to run BERT inside the model during training so that we don't have to compute the embeddings ourselves at inference time.
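A minimal sketch of what such a layer can look like, assuming a TF Hub preprocessor plus a small BERT encoder and the pooled output as the embedding (the exact Hub handles and details are illustrative, not necessarily our real definition):

```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # noqa: F401  (registers the ops needed by the BERT preprocessor)


class BertEmbeddingLayer(tf.keras.layers.Layer):
    """Maps raw text to a fixed-size BERT embedding."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Illustrative Hub handles: a BERT preprocessor and a small BERT encoder.
        self.preprocessor = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
        self.encoder = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2",
            trainable=False)

    def call(self, text):
        encoder_inputs = self.preprocessor(text)
        outputs = self.encoder(encoder_inputs)
        # Use the pooled [CLS] representation as the text embedding.
        return outputs["pooled_output"]
```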
We train on cached_train (251,429 rows) and evaluate on cached_test (51,748 rows). With batch_size=4096 and without BertEmbeddingLayer, training takes less than 30 minutes on an AWS ml.g4dn.2xlarge instance (1 GPU, 16 GB).
Once BertEmbeddingLayer is included, it is impossible to train the model with batch_size=4096: the process is OOM-killed. With batch_size=2048, the TensorFlow ETA is estimated at 500 hours. Using a more powerful ml.p3.2xlarge instance (1 V100 GPU, 8 vCPUs) does not reduce the ETA.
We also tried performing the tokenization before the fit step, but this did not improve the ETA, and using Hugging Face tokenizers and encoders was no better.
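Concretely, the pre-tokenization attempt looked roughly like the following sketch (feature and dataset names are illustrative), with the preprocessor applied in a dataset.map so that only the encoder runs inside fit:

```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # noqa: F401

preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

def tokenize(features):
    # Tokenize the raw titles once in the input pipeline instead of inside the model.
    features["title_tokens"] = preprocessor(features["movie_title"])
    return features

cached_train = (train
                .batch(2048)
                .map(tokenize, num_parallel_calls=tf.data.AUTOTUNE)
                .cache())
```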
We are left with one option, which we would rather avoid because we want the model itself to produce the embedding at inference time: computing the embeddings outside the model and feeding the result in as ordinary features, as sketched below.
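Roughly, the idea is the following sketch (names are illustrative, and BertEmbeddingLayer is the layer sketched above): run the encoder once over every unique title and attach the resulting vector to each example as a plain numeric feature.

```python
import numpy as np
import tensorflow as tf

bert = BertEmbeddingLayer()  # the layer sketched above

# 1. Embed every unique title once, in manageable batches.
unique_titles = np.unique(np.concatenate(
    [t.numpy() for t in train.map(lambda x: x["movie_title"]).batch(4096)]))
title_vectors = np.concatenate([
    bert(tf.constant(unique_titles[i:i + 256])).numpy()
    for i in range(0, len(unique_titles), 256)
])

# 2. Attach the precomputed vector to each example as an ordinary numeric feature.
title_to_row = tf.keras.layers.StringLookup(
    vocabulary=unique_titles, num_oov_indices=0)  # every title is in the vocabulary
vector_table = tf.constant(title_vectors)

def add_embedding(features):
    features["title_embedding"] = tf.gather(
        vector_table, title_to_row(features["movie_title"]))
    return features

train_with_embeddings = train.map(add_embedding, num_parallel_calls=tf.data.AUTOTUNE)
```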
We'd really appreciate your help. Do you have any suggestions for improvements to obtain a reasonable ETA, or advice on the correct way to implement BERT in TFRS?
That's the cruel truth: BERT is expensive to train. Basically, you have three choices:
1. Try a smaller BERT, like the Small BERT models on this page, and this one. We have tried it before; I seem to remember it is trainable, but slow.
2. Use a machine with more GPU capacity, like a 4-GPU p3.8xlarge, or even try distributed training across multiple machines (the tuning work can be tricky); see the first sketch after this list.
3. Pre-generate all the embeddings and load them as weights into an embedding layer. This removes the cost of running BERT on the fly during training, and it is the same idea as the option you listed. You can also fine-tune BERT on your dataset before generating the embeddings, which generally gives better performance; see the second sketch after this list.
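For option 2, on a single multi-GPU machine the main change is building and compiling the model under a distribution strategy. A rough sketch (model and dataset names as in the basic_retrieval tutorial the question is based on; hyper-parameters are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Model classes and towers as in the basic_retrieval example.
    model = MovielensModel(user_model, movie_model)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

# The global batch size is split across replicas, so it usually needs to be scaled up
# (and the learning rate re-tuned) when adding GPUs.
model.fit(cached_train, epochs=3)
```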
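For option 3, the pre-generated vectors can be loaded into a frozen tf.keras.layers.Embedding keyed by title, so the tower still maps raw text to an embedding at inference without running BERT. A sketch, reusing the unique_titles and title_vectors arrays from the pre-generation sketch above:

```python
import numpy as np
import tensorflow as tf

# unique_titles / title_vectors: the arrays produced by the pre-generation step above.
title_lookup = tf.keras.layers.StringLookup(vocabulary=unique_titles)

# Row 0 of the table corresponds to the OOV index added by StringLookup; give it zeros.
embedding_matrix = np.vstack(
    [np.zeros((1, title_vectors.shape[1])), title_vectors])

title_embedding = tf.keras.Sequential([
    title_lookup,
    tf.keras.layers.Embedding(
        input_dim=title_lookup.vocabulary_size(),
        output_dim=title_vectors.shape[1],
        embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
        trainable=False),  # keep the BERT vectors frozen
])

# Usable inside a tower just like any other embedding, e.g. title_embedding(raw_titles),
# so inference still goes from raw text to an embedding without running BERT.
```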