Open Pengjie-W opened 4 months ago
Hello @Pengjie-W, if I am not wrong, using a single decoder makes inference much slower, so I am not sure I will make another version for this very soon. I suggest the following workaround instead: you can subclass `OnnxPredictor`, remove the lines involving `decoder_with_past`, and adapt the generation loop. For example, you can replace the lines
```python
if past_key_values is None:
    # First step: run the decoder on the full sequence and build the cache.
    out_decoder = self.decoder.run(None, {'input_ids': input_ids,
                                          'encoder_hidden_states': out_encoder})
    logits = out_decoder[0]
    past_key_values = {'past_key_value_input_' + str(k): out_decoder[k + 1]
                       for k in range(len(out_decoder[1:]))}
else:
    # Later steps: feed only the newest token plus the cached key/values.
    out_decoder = self.decoder_with_past.run(None, {'input_ids': input_ids[:, -1:],
                                                    **past_key_values})
    logits = out_decoder[0]
    past_key_values = {'past_key_value_input_' + str(i): pkv
                       for i, pkv in enumerate(out_decoder[1:])}
```
with
```python
# Always re-run the full sequence through the single decoder; no cache is kept.
out_decoder = self.decoder.run(None, {'input_ids': input_ids,
                                      'encoder_hidden_states': out_encoder})
logits = out_decoder[0]
```
Then you can delete the file `decoder_with_past.onnx`. I hope this helps. Greetings, Ambroise
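To illustrate why the single-decoder workaround is slower, here is a minimal, self-contained sketch of a cache-free greedy generation loop. The `toy_decoder` (and its next-token rule) is a made-up stand-in for `self.decoder.run(...)`, not part of the real `OnnxPredictor`; the point is that the full sequence is re-fed to the decoder at every step instead of reusing `past_key_values`:

```python
VOCAB_SIZE = 10

def toy_decoder(input_ids):
    # Stand-in for self.decoder.run(...): returns one logits row per position.
    # Hypothetical toy rule: the model predicts (token + 1) % VOCAB_SIZE.
    logits = []
    for tok in input_ids:
        row = [0.0] * VOCAB_SIZE
        row[(tok + 1) % VOCAB_SIZE] = 1.0
        logits.append(row)
    return logits

def generate(start_id, steps):
    # Cache-free loop: the FULL sequence is re-run through the decoder at
    # every step (quadratic work overall), mirroring the workaround above.
    input_ids = [start_id]
    for _ in range(steps):
        logits = toy_decoder(input_ids)
        last = logits[-1]
        next_id = max(range(VOCAB_SIZE), key=last.__getitem__)  # greedy argmax
        input_ids.append(next_id)
    return input_ids

print(generate(3, 4))  # → [3, 4, 5, 6, 7]
```

With the cached variant, each later step would instead pass only `input_ids[:, -1:]` plus the stored key/value tensors, which is what `decoder_with_past.onnx` provides.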
Can I use a single decoder.onnx instead of both decoder.onnx and decoder_with_past.onnx? The two files take up more space. If possible, could you modify the code and provide such a version? Thank you very much.