vsoesanto opened 2 years ago
@Ki6an I also tried adding output_hidden_states=True to the ONNX model's config, and I made sure the argument is passed in the model.generate() call as well, but still no luck. Any idea how I can retrieve the encoder/decoder hidden states?
Sorry for the late reply. You can easily get the encoder hidden states by passing the input_ids and attention_mask to the encoder, as shown below:
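from fastT5 import export_and_get_onnx_model

# ... (tokenize your input to get input_ids and attention_mask)
model = export_and_get_onnx_model(model_name)
encoder = model.encoder
# runs the exported ONNX encoder session on the inputs
hidden_state = encoder(input_ids, attention_mask)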
For the decoder, however, you need to make several changes. You can start by making changes here: decoder_output[0] is the last_hidden_state, so make the wrapper return that value as well:
return (
    decoder_output[0],  # decoder last_hidden_state (new output)
    self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),  # logits
    decoder_output[1],  # past key/values
)
Also, make the same change for the decoder.
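When you read those values back from the ORT session, keep in mind that putting the hidden state first in the return tuple shifts every output index by one. Logits move from index 0 to index 1, and the past key/values start at index 2. The sketch below is a hypothetical helper, not part of fastT5. It assumes the exported decoder returns (last_hidden_state, logits, *flattened_past_key_values), with the nested past tuple flattened into separate ONNX outputs by the export. Plain strings stand in for the numpy arrays that session.run would return.

```python
from typing import Any, NamedTuple, Sequence, Tuple


class DecoderStepOutputs(NamedTuple):
    """Named view over one decoder ORT run (hypothetical helper)."""
    last_hidden_state: Any
    logits: Any
    past_key_values: Tuple[Any, ...]


def unpack_decoder_outputs(ort_outputs: Sequence[Any]) -> DecoderStepOutputs:
    """Split the flat list returned by session.run(None, inputs).

    Assumes the modified return order: (last_hidden_state, logits, *past).
    Before the change, logits were at index 0 and the past started at index 1.
    """
    if len(ort_outputs) < 2:
        raise ValueError("expected at least the hidden state and the logits")
    return DecoderStepOutputs(
        last_hidden_state=ort_outputs[0],
        logits=ort_outputs[1],
        past_key_values=tuple(ort_outputs[2:]),
    )


if __name__ == "__main__":
    # Stand-in for session.run(...) output: strings instead of numpy arrays.
    fake = ["hidden", "logits", "k0", "v0", "ck0", "cv0"]
    out = unpack_decoder_outputs(fake)
    print(out.logits, len(out.past_key_values))
```

The returned hidden state is the raw decoder_output[0]. The d_model ** -0.5 rescaling is applied only on the way into lm_head. If the export call passes an output_names list to torch.onnx.export, that list probably also needs a new name at the front so the names stay aligned with the outputs.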
Then, retrieve those values from the ORT session here and here. Finally, pass those values here as the decoder_hidden_states= argument.
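In transformers, decoder_hidden_states on a Seq2SeqLMOutput is a tuple of tensors, normally one for the embeddings plus one per layer. When generate() is called with output_hidden_states=True and return_dict_in_generate=True, it appends that field once per generated step. The ONNX decoder exposes only the final layer, so one option is to pass it as a one-element tuple. The sketch below shows the resulting nesting under that assumption. StepOutput is a hypothetical stand-in for Seq2SeqLMOutput, so the example runs without transformers, and the loop only mimics how generate() accumulates the field.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class StepOutput:
    """Hypothetical stand-in for transformers' Seq2SeqLMOutput (only the fields used here)."""
    logits: Any
    past_key_values: Optional[Tuple[Any, ...]] = None
    decoder_hidden_states: Optional[Tuple[Any, ...]] = None


def make_step_output(last_hidden_state: Any, logits: Any, past: Tuple[Any, ...]) -> StepOutput:
    # Only the final decoder layer is exported, so wrap it as a 1-tuple to
    # match the "tuple of per-layer tensors" shape callers expect.
    return StepOutput(
        logits=logits,
        past_key_values=past,
        decoder_hidden_states=(last_hidden_state,),
    )


def collect_hidden_states(steps: List[StepOutput]) -> Tuple[Tuple[Any, ...], ...]:
    # Mimics generate(): one decoder_hidden_states entry per generated step.
    collected: Tuple[Tuple[Any, ...], ...] = ()
    for step in steps:
        collected += (step.decoder_hidden_states,)
    return collected


if __name__ == "__main__":
    steps = [make_step_output(f"h{t}", f"logits{t}", ()) for t in range(3)]
    print(collect_hidden_states(steps))  # (('h0',), ('h1',), ('h2',))
```

With this layout, indexing the generate() result as decoder_hidden_states[step][0] would give the final-layer hidden state for that step.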
I converted a locally saved T5 checkpoint to ONNX using FastT5:
I tested it for inference:
The hidden states are all None.
Is there any way that I can retrieve the hidden states for both encoder and decoder?