microsoft / onnxruntime-training-examples

Examples for using ONNX Runtime for model training.

Training a BERT model is failing on android mobile device #175

Closed SJ4949 closed 4 months ago

SJ4949 commented 10 months ago

I was trying to train a pretrained BERT model on an Android device, but it fails because the node below is not implemented.

RuntimeError: C:\a_work\1\s\orttraining\orttraining\training_api\module.cc:175 onnxruntime::training::api::Module::Module [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Reshape(19) node with name 'Reshape__265'

Is there a plan to support this? If so, please let me know the schedule.

wschin commented 10 months ago

I assume your model comes from PyTorch. If you're able to lower the operator set (opset) used in your model, you can work around this missing operator. Follow this example, which uses operator set 18:

import torch

# `model`, `args`, and `kwargs` stand for your own module and its example inputs.
# Export the model with dynamic shapes enabled.
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
export_output = torch.onnx.dynamo_export(
    model,
    *args,
    **kwargs,
    export_options=export_options)
export_output.save("my_dynamic_model.onnx")

prathikr commented 10 months ago

@SJ4949 did @wschin's suggestion resolve your issue?

SJ4949 commented 9 months ago

Yes that really helped. Thank you.

But proceeding further, I am getting this error:

terminating with uncaught exception of type Ort::Exception: /onnxruntime_src/orttraining/orttraining/training_api/module.cc:538 onnxruntime::common::Status onnxruntime::training::api::Module::TrainStep(const std::vector &, std::vector &) [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape__316' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:40 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, onnxruntime::TensorShapeVector &, bool) size != 0 && (input_shape_size % size) == 0 was false. The input tensor cannot be reshaped to the requested shape. Input shape:{5,2,128,1,64}, requested shape:{10,-1,10}
01-17 12:02:56.690 24719 24866 E libc++abi:

@wschin @prathikr
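
For anyone reading the Reshape failure above, the arithmetic behind it can be reproduced in a few lines. This is a minimal sketch that only mirrors the condition quoted in the error message (`size != 0 && (input_shape_size % size) == 0`). It is not ONNX Runtime's code. The function name `resolve_reshape` is made up for illustration, and the sketch ignores the ONNX rule that a `0` in the requested shape copies the input dimension.

```python
from math import prod


def resolve_reshape(input_shape, requested_shape):
    """Resolve a single -1 in requested_shape, or raise ValueError.

    Hypothetical helper that mirrors the check quoted in the error message.
    It does not handle 0 entries, which ONNX Reshape treats specially.
    """
    if requested_shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")

    total = prod(input_shape)  # number of elements in the input tensor
    known = prod(d for d in requested_shape if d != -1)  # product of fixed dims

    if -1 in requested_shape:
        # The inferred dimension must be a whole number.
        if known == 0 or total % known != 0:
            raise ValueError(
                f"cannot reshape {input_shape} ({total} elements) to "
                f"{requested_shape}: {total} is not divisible by {known}"
            )
        return [total // known if d == -1 else d for d in requested_shape]

    if known != total:
        raise ValueError(
            f"cannot reshape {input_shape} ({total} elements) to {requested_shape}"
        )
    return list(requested_shape)


input_shape = [5, 2, 128, 1, 64]
print(prod(input_shape))  # 81920 elements

try:
    resolve_reshape(input_shape, [10, -1, 10])
except ValueError as err:
    print(err)  # 81920 is not divisible by 10 * 10 = 100

# A requested shape whose fixed dimensions divide 81920 resolves cleanly.
print(resolve_reshape(input_shape, [10, -1, 64]))  # [10, 128, 64]
```

The input holds 5 × 2 × 128 × 1 × 64 = 81920 elements, while the fixed dimensions of the requested shape multiply to 100, so no integer fits the `-1` slot. One possible reading, not confirmed in this thread, is that a shape value in the exported graph does not match the tensors fed to it at training time.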

prathikr commented 4 months ago

@SJ4949 can you confirm whether your issue is still present with the latest ORT/torch?

prathikr commented 4 months ago

Closing issue as STALE