Closed maziyarpanahi closed 2 years ago
I've reproduced this issue - will discuss with the team what we can do to generally support SavedModel saving.
Hi @maziyarpanahi! I've talked this over with the team and although we offer `SavedModel` support for saving, it doesn't work with all models and we're not sure how feasible it will be to update all of them in the near future. Can we ask what your use case for `SavedModel` is, compared to just `save_pretrained` or `save_weights`? There may be another approach.
Hi @Rocketknight1
The use case is to serve the fine-tuned (or already uploaded) model in TensorFlow. The SavedModel format is the only way to avoid going from PyTorch to onnx-tf and then to TensorFlow.
There are some architectures that don't have any TF support, which I understand, and I normally either wait or go through ONNX to TF. However, DebertaV2 already supports `saved_model` for the fill-mask and token-classification models, so I really thought this could be a bug if it only fails in `DebertaV2ForSequenceClassification`.
After some investigation, the cause is the different `Dropout` being used. In the `TokenClassification` model, standard Keras `Dropout` is used; in the `SequenceClassification` model, `StableDropout` is used. This change is present in the original PyTorch models too, although I'm not sure why.
I don't think this is a bug with an easy fix, unfortunately - I'm not the model author, so I don't want to change the dropout type. However, you could probably make a local fork of `transformers` and swap the `StableDropout` for `Dropout`, which would allow you to save the model as `SavedModel`. I'll talk to the other team members and see what they think!
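For anyone trying this, the general pattern is to make the model pick up the standard `Dropout` class before it is built. The sketch below is self-contained and uses stand-in classes so it runs without TensorFlow or `transformers` installed - in a real local fork you would instead edit the `StableDropout` usage in the DebertaV2 TF modeling file, and the class and attribute names here are illustrative, not the library's actual ones:

```python
# Minimal sketch of the "swap StableDropout for Dropout" idea using stand-in
# classes. In a real fork you would edit the modeling file in transformers;
# nothing below is actual library code.

class StableDropout:
    """Stand-in for the custom dropout layer that blocks SavedModel export."""
    def __init__(self, rate):
        self.rate = rate


class Dropout:
    """Stand-in for the standard Keras Dropout layer."""
    def __init__(self, rate):
        self.rate = rate


class SequenceClassifier:
    """Stand-in model: resolves its dropout class at construction time."""
    dropout_cls = StableDropout

    def __init__(self, rate=0.1):
        self.dropout = self.dropout_cls(rate)


# The swap: rebind the class before any model is instantiated, so every
# new SequenceClassifier builds with the standard Dropout instead.
SequenceClassifier.dropout_cls = Dropout

model = SequenceClassifier()
print(type(model.dropout).__name__)  # -> Dropout
```

Against the real library, the equivalent change in a fork replaces the `StableDropout` layer with a Keras `Dropout` in the sequence-classification head, so the exported graph contains only standard Keras layers.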
Thanks @Rocketknight1
This is a great help! I will make that change and try to fine-tune a base model on IMDB to see whether I can save it as a SavedModel, and I'll also share the eval stats for quality control.
Hi @Rocketknight1
For future reference: I replaced `StableDropout` with `Dropout`, and the issue with saving as SavedModel was resolved. The evals from 3-4 models trained on IMDB showed no difference between `StableDropout` and `Dropout`, so there is no performance tradeoff.
I can prepare a PR if you decide to use Keras `Dropout` inside `TFDebertaV2ForSequenceClassification`.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Environment info
`transformers` version: 4.17.0

Who can help
@LysandreJik
Information
Model I am using (Bert, XLNet ...): kamalkraj/deberta-v2-xlarge
To reproduce
Steps to reproduce the behavior:
Use `TFDebertaV2ForSequenceClassification` with `saved_model=True` to save as a TensorFlow SavedModel.

Reference: https://huggingface.co/docs/transformers/model_doc/deberta-v2#transformers.TFDebertaV2ForSequenceClassification
Expected behavior
It is expected to save `TFDebertaV2ForSequenceClassification` models as a TensorFlow SavedModel, similar to `TFDebertaV2Model` models.