Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗, letting you quickly train your T5 models.
MIT License

TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask' in onnx_predict function #8

Closed: farshadfiruzi closed this issue 3 years ago

farshadfiruzi commented 3 years ago

Hello, when I run the fine-tuned mT5 model under ONNX, I get the following error:

```
TypeError                                 Traceback (most recent call last)
in 
----> 1 model.onnx_predict(text)

~\AppData\Roaming\Python\Python38\site-packages\simplet5\simplet5.py in onnx_predict(self, source_text)
    469         """ generates prediction from ONNX model """
    470         token = self.onnx_tokenizer(source_text, return_tensors="pt")
--> 471         tokens = self.onnx_model.generate(
    472             input_ids=token["input_ids"],
    473             attention_mask=token["attention_mask"],

C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs)
     26     def decorate_context(*args, **kwargs):
     27         with self.__class__():
---> 28             return func(*args, **kwargs)
     29     return cast(F, decorate_context)
     30 

C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
   1051             input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
   1052         )
-> 1053         return self.beam_search(
   1054             input_ids,
   1055             beam_scorer,

C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   1788             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   1789 
-> 1790             outputs = self(
   1791                 **model_inputs,
   1792                 return_dict=True,

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask'
```

I tried to downgrade transformers and onnxruntime, but the error still remains.
Shivanandroy commented 3 years ago

Which transformers/simpleT5 version are you using?

farshadfiruzi commented 3 years ago

I am using transformers==4.8.2 and simpleT5==0.1.1. I also tried newer versions of transformers (4.9.0 and 4.9.1), but that did not fix the error.

Shivanandroy commented 3 years ago

The issue is fixed in the latest version. Install the latest version: `pip install --upgrade simplet5`
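For reference, a quick way to confirm which version actually got installed after the upgrade (a minimal sketch, assuming Python 3.8+, where `importlib.metadata` is in the standard library):

```python
# Print the installed simpleT5 version after running `pip install --upgrade simplet5`.
from importlib.metadata import version

print(version("simplet5"))
```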

farshadfiruzi commented 3 years ago

It works perfectly now. Thanks a lot.

radurevutchi commented 3 years ago

May I ask how exactly you fixed this? I'm looking for the PR or code change that fixed it. I'm trying to adapt this code to MBart and I'm getting the exact same error. @Shivanandroy @farshadfiruzi

Shivanandroy commented 3 years ago

Hi @radurevutchi, the current version of SimpleT5 only supports training and inference for T5/mT5/byT5 models. Support for quantization and ONNX Runtime has been dropped because of version conflict issues.

Below is what SimpleT5 offers:

from simplet5 import SimpleT5
model = SimpleT5()

model.from_pretrained("t5","t5-base")

model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
            eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len = 512, 
            target_max_token_len = 128,
            batch_size = 8,
            max_epochs = 5,
            use_gpu = True,
            outputdir = "outputs",
            early_stopping_patience_epochs = 0,
            precision = 32
            )

# load trained T5 model
model.load_model("t5","path/to/trained/model/directory", use_gpu=False)

# predict
model.predict("input text for prediction")

If you want to adapt it for mBART or any other model, I would encourage you to write separate methods for quantization and ONNX support in addition to the training method. How to export your model to ONNX: https://huggingface.co/transformers/serialization.html
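As an illustration only (not SimpleT5's dropped implementation), here is a minimal sketch of such an export step for mBART using `torch.onnx.export`. The checkpoint name, output path, and opset version are placeholder assumptions; it exports only the encoder, and a full `generate()` pipeline would also need the decoder (ideally with cached past key values) exported, which tools such as fastT5 automate for T5.

```python
# Minimal, illustrative sketch: export an mBART encoder to ONNX.
# Assumptions: checkpoint name, output file, and opset are placeholders.
import torch
from transformers import AutoTokenizer, MBartForConditionalGeneration

model_name = "facebook/mbart-large-cc25"   # assumed checkpoint; use your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).eval()
model.config.return_dict = False           # return plain tuples, which trace/export cleanly

sample = tokenizer("Hello world", return_tensors="pt")

# Export only the encoder; its forward takes (input_ids, attention_mask).
torch.onnx.export(
    model.get_encoder(),
    (sample["input_ids"], sample["attention_mask"]),
    "mbart_encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=12,
)
```

The exported file can then be run with `onnxruntime.InferenceSession("mbart_encoder.onnx")`; wiring it into beam search is the part that needs the extra per-model methods mentioned above.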