Ki6an / fastT5

⚡ boost inference speed of T5 models by 5x & reduce the model size by 3x.

Unable to retrieve hidden_states #52

Open vsoesanto opened 2 years ago

vsoesanto commented 2 years ago

I converted a locally saved T5 checkpoint to ONNX using fastT5:

>>> from fastT5 import export_and_get_onnx_model
>>> from transformers import AutoTokenizer

>>> model_name = "path/to/checkpoint"
>>> model = export_and_get_onnx_model(model_name)

I tested it for inference:

>>> tokenizer = AutoTokenizer.from_pretrained(model_name)

>>> token = tokenizer(input_terms, max_length=512 * 2, padding=True, truncation=True, return_tensors='pt')

>>> out = model.generate(input_ids=token['input_ids'].to('cpu'),
                            attention_mask=token['attention_mask'].to('cpu'),
                            return_dict_in_generate=True,
                            max_length=512 * 2,
                            num_beams=1,
                            output_scores=True,
                            output_hidden_states=True)

>>> out.encoder_hidden_states
>>> out.decoder_hidden_states
(None,
 None,
 None,
 None,
...

>>> out
GreedySearchEncoderDecoderOutput(sequences=tensor([[  0, 119, 114, 102, 108, 111, 108, 125, 120, 112, 100, 101,  35,  53, ...
...
), encoder_attentions=None, encoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, decoder_hidden_states=(None, None, None, ..., None))

The hidden states are all None.

Is there any way that I can retrieve the hidden states for both encoder and decoder?

vsoesanto commented 2 years ago

@Ki6an I also tried adding output_hidden_states=True to the ONNX model's config, and I made sure this argument is passed in the model.generate() call, but still no luck. Any idea how I can retrieve the encoder/decoder hidden states?

Ki6an commented 2 years ago

Sorry for the late reply.

You can get the encoder's hidden states easily, just by passing the input_ids and attention_mask to the encoder, as shown below:

from fastT5 import export_and_get_onnx_model

model = export_and_get_onnx_model(model_name)
encoder = model.encoder

# the exported ONNX encoder returns the encoder hidden states for the given inputs
hidden_state = encoder(input_ids, attention_mask)
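
For example (a minimal, self-contained sketch; the model name and input text are just placeholders):

from fastT5 import export_and_get_onnx_model
from transformers import AutoTokenizer

model_name = "t5-small"  # placeholder; use your own checkpoint path
model = export_and_get_onnx_model(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokens = tokenizer("some input text", return_tensors="pt")

# call the exported ONNX encoder directly
encoder_out = model.encoder(tokens["input_ids"], tokens["attention_mask"])
# depending on the fastT5 version, encoder_out is either the hidden-state tensor
# itself or a model output wrapping it (then access it via encoder_out[0])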

But for the decoder, you need to make quite a few changes. You can start by modifying the return here:

https://github.com/Ki6an/fastT5/blob/20441b33394e71f7612f39f228ecbe1925cd10ae/fastT5/onnx_models_structure.py#L52-L62

decoder_output[0] is the last_hidden_state.

Make the forward return that value as well:

     return (
         decoder_output[0],                                                 # last_hidden_state (newly exposed)
         self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),  # logits
         decoder_output[1],                                                 # past key values
     )

Also, make the same changes for the decoder.
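
A rough sketch of that change in the other decoder wrapper in fastT5/onnx_models_structure.py (the surrounding forward() stays as it is in your version of the file; this is illustrative, not a tested patch):

# in the second decoder wrapper's forward() (the one that also takes past key values),
# right after the self.decoder(...) call that produces decoder_output:
return (
    decoder_output[0],                                                 # last_hidden_state (newly exposed)
    self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),   # logits, as before
    decoder_output[1],                                                 # past key values, as before
)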

Then, retrieve those values from the ORT session here and here.
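
Concretely, once the graph exposes the extra output, the list returned by the ONNX Runtime session gains an extra entry, so the decoder session wrappers in fastT5/onnx_models.py have to unpack it. A rough sketch (variable names are illustrative, and the output order depends on how you declare the outputs when re-exporting):

import torch

# inside the decoder session wrapper's forward()
decoder_outputs = decoder_sess.run(None, inputs)   # list of numpy arrays from ONNX Runtime

# assuming the new hidden-state output was declared first at export time
hidden_state = torch.from_numpy(decoder_outputs[0])
logits = torch.from_numpy(decoder_outputs[1])
past_key_values = decoder_outputs[2:]              # the remaining outputs are the past key values

Note that the model has to be re-exported after the change above, and you will probably also want to add a name for the new output to the export's output_names so it is easy to pick out of the session results.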

Finally, pass those values here:

https://github.com/Ki6an/fastT5/blob/20441b33394e71f7612f39f228ecbe1925cd10ae/fastT5/onnx_models.py#L199

as decoder_hidden_states=
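
Something along these lines, assuming the linked line is where the model's forward builds its transformers Seq2SeqLMOutput (variable names here are illustrative):

# at the end of the model's forward() in fastT5/onnx_models.py
# (Seq2SeqLMOutput comes from transformers.modeling_outputs)
return Seq2SeqLMOutput(
    logits=logits,
    past_key_values=past_key_values,
    decoder_hidden_states=(hidden_state,),  # a tuple; only the last layer's hidden state is available here
)

Since the exported decoder only returns its last hidden state, generate() will then report a one-element tuple per step rather than one tensor per layer.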