Closed farshadfiruzi closed 3 years ago
Which transformers/simpleT5 version are you using?
I am using transformers==4.8.2 and simpleT5==0.1.1. I also tried newer versions of transformers (4.9.0 and 4.9.1), but that didn't fix the error.
The issue is fixed in the latest version.
Install the latest version: `pip install --upgrade simplet5`
It works perfectly now. Thanks a lot.
May I ask, how exactly did you fix this? I'm looking for the PR or code change which fixed it - trying to adapt this code to MBart and I'm getting the exact same error. @Shivanandroy @farshadfiruzi
Hi @radurevutchi, the current version of SimpleT5
only supports training/inference for T5/mT5/byT5 models. Support for quantization and ONNX runtime was dropped because of version conflict issues.
Below is what SimpleT5 offers:
```python
from simplet5 import SimpleT5

model = SimpleT5()
model.from_pretrained("t5", "t5-base")

model.train(
    train_df=train_df,  # pandas dataframe with 2 columns: source_text & target_text
    eval_df=eval_df,    # pandas dataframe with 2 columns: source_text & target_text
    source_max_token_len=512,
    target_max_token_len=128,
    batch_size=8,
    max_epochs=5,
    use_gpu=True,
    outputdir="outputs",
    early_stopping_patience_epochs=0,
    precision=32,
)

# load trained T5 model
model.load_model("t5", "path/to/trained/model/directory", use_gpu=False)

# predict
model.predict("input text for prediction")
```
If you want to adapt it for mBART or any other model, I would encourage you to write separate methods for quantization and ONNX support in addition to the training method. How to export your model to ONNX: https://huggingface.co/transformers/serialization.html
Hello, when I run the fine-tuned mT5 model under ONNX, I get the following error:
`TypeError Traceback (most recent call last)