Ikomia-dev / onnx-donut

Export Donut model to onnx and run it with onnxruntime
Apache License 2.0

Just one decoder.onnx #1

Open Pengjie-W opened 4 months ago

Pengjie-W commented 4 months ago

Can I use just decoder.onnx instead of both decoder.onnx and decoder_with_past.onnx? Keeping both files takes up more space. If possible, could you modify the code and provide a version that works this way? Thank you very much.

ambroiseb commented 4 months ago

Hello @Pengjie-W, if I am not mistaken, using a single decoder makes inference much slower, because without the cached key/values the decoder has to reprocess the whole sequence at every step. I am not sure I will make another version for this very soon, so I suggest the following workaround: subclass OnnxPredictor, remove the lines involving decoder_with_past, and adapt the generation loop. For example, you can replace the lines

            if past_key_values is None:
                out_decoder = self.decoder.run(None, {'input_ids': input_ids, 'encoder_hidden_states': out_encoder})
                logits = out_decoder[0]
                past_key_values = {'past_key_value_input_' + str(k): out_decoder[k + 1] for k in
                                   range(len(out_decoder[1:]))}

            else:
                out_decoder = self.decoder_with_past.run(None, {'input_ids': input_ids[:, -1:],
                                                                **past_key_values})
                logits = out_decoder[0]
                past_key_values = {'past_key_value_input_' + str(i): pkv for i, pkv in enumerate(out_decoder[1:])}

with

            out_decoder = self.decoder.run(None, {'input_ids': input_ids, 'encoder_hidden_states': out_encoder})
            logits = out_decoder[0]

Then you can delete the file decoder_with_past.onnx. I hope this helps. Best regards, Ambroise
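
To illustrate how the adapted loop behaves, here is a minimal, self-contained sketch of greedy decoding with a single decoder. ToyDecoder, generate_without_cache, and the token ids are hypothetical and not part of onnx-donut. ToyDecoder only imitates the session.run(None, feeds) call shape, using nested lists instead of NumPy arrays so the example runs on its own. It also counts how many token positions the decoder processes, which gives a rough idea of why dropping decoder_with_past is slower.

```python
class ToyDecoder:
    """Hypothetical stand-in for the decoder session: run(None, feeds) -> [logits]."""

    def __init__(self, vocab_size, eos_token_id):
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id
        self.positions_processed = 0  # total token positions seen across all calls

    def run(self, output_names, feeds):
        input_ids = feeds['input_ids']  # nested lists shaped (batch, seq_len)
        logits = []
        for row in input_ids:
            self.positions_processed += len(row)
            row_logits = []
            for token in row:
                # Toy rule: predict token + 1, or EOS once the vocabulary is exhausted.
                nxt = token + 1 if token + 1 < self.vocab_size else self.eos_token_id
                scores = [0.0] * self.vocab_size
                scores[nxt] = 1.0
                row_logits.append(scores)
            logits.append(row_logits)
        return [logits]


def generate_without_cache(decoder, out_encoder, start_token_id, eos_token_id, max_length):
    """Greedy decoding that feeds the full sequence to one decoder at every step."""
    input_ids = [[start_token_id]]  # batch containing a single sequence
    while len(input_ids[0]) < max_length:
        # No past key/values are kept, so the whole sequence is passed each time.
        out_decoder = decoder.run(None, {'input_ids': input_ids,
                                         'encoder_hidden_states': out_encoder})
        logits = out_decoder[0]
        last_scores = logits[0][-1]  # only the newest position decides the next token
        next_token = max(range(len(last_scores)), key=last_scores.__getitem__)
        input_ids[0].append(next_token)
        if next_token == eos_token_id:
            break
    return input_ids[0]


decoder = ToyDecoder(vocab_size=6, eos_token_id=0)
tokens = generate_without_cache(decoder, out_encoder=None, start_token_id=2,
                                eos_token_id=0, max_length=16)
print(tokens)                       # [2, 3, 4, 5, 0]
print(decoder.positions_processed)  # 10
```

In this toy run the decoder processes 1 + 2 + 3 + 4 = 10 positions to produce four tokens, whereas a decoder reusing cached key/values would process one new position per step (4 in total). With a real onnxruntime session, input_ids would be a NumPy array rather than nested lists.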