JingyaHuang opened 1 year ago
This is weird, as I've verified that both query and key are float16 and on CUDA. Besides, the export of the decoder without past doesn't have that issue. Will investigate further when I have the bandwidth...
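To make the dtype verification above concrete, here is a minimal sketch of the kind of check involved. The helper name `check_attention_dtypes` is an illustrative assumption, not an optimum internal; it simply fails early when query and key dtypes diverge (e.g. a past key left in fp32 after mixed-precision training).

```python
import torch

def check_attention_dtypes(query: torch.Tensor, key: torch.Tensor) -> None:
    """Hypothetical helper: raise a clear error if query/key dtypes diverge
    before the attention matmul, instead of failing deep inside tracing."""
    if query.dtype != key.dtype:
        raise TypeError(
            f"query is {query.dtype} but key is {key.dtype}; "
            "cast both to the same dtype before matmul"
        )

q = torch.randn(2, 4, 8, dtype=torch.float16)
k = torch.randn(2, 4, 8, dtype=torch.float32)  # e.g. a cached past key kept in fp32
try:
    check_attention_dtypes(q, k)
except TypeError as err:
    print(err)
```

In the reported case both tensors already pass this check, which is why the failure is surprising.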
[Update] I won't fix it unless there is large need from the community, as proper inference with ORT should go through subclasses of `ORTModel` instead of `ORTTrainer`. The inference part of `ORTTrainer` is just for quick testing.
Will test the export again once this PR #749 is merged.
System Info
Who can help?
@JingyaHuang
Reproduction
Using `evaluate()` after mixed-precision training with `ORTTrainer`. But I think the problem is general to converting PyTorch models with FP16 weights to ONNX.

Error message
Expected behavior
The decoder with past with FP16 weights can be successfully exported to an ONNX model.
Contribution
I can take a closer look, but I don't have the bandwidth for the moment.