microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training] Shape change is not sync when serialize graph to proto #19741

Open guyang3532 opened 7 months ago

guyang3532 commented 7 months ago

Describe the issue

In ORT training, we run graph transformations/optimizations on forward_model and then invoke 'Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_)' to get gradient_model. If the shapes of node args are changed during the transformation/optimization step, the change is not visible in gradient_model's graph. This is because the shape change is stored in Graph::node_args_, which is not synced when the graph is serialized to proto.
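To illustrate the suspected mechanism, here is a toy sketch rather than ORT code. ToyProto, ToyGraph, SyncShapesToProto and ReloadedGraphSeesShape are hypothetical names invented for this example, and the assumption that serialization returns a stored proto without re-reading the in-memory node args is based only on the behavior described in this report.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for illustration only; these are NOT onnxruntime types.
using Shape = std::vector<long long>;

// Hypothetical serialized form: shapes known at serialization time.
struct ToyProto {
  std::map<std::string, Shape> value_info;
};

// Hypothetical graph that keeps a stored proto plus in-memory arg shapes.
class ToyGraph {
 public:
  explicit ToyGraph(const ToyProto& proto)
      : proto_(proto), arg_shapes_(proto.value_info) {}

  bool HasShape(const std::string& name) const {
    return arg_shapes_.count(name) != 0;
  }

  // Updates only the in-memory state; the stored proto is left untouched.
  void SetShape(const std::string& name, Shape shape) {
    arg_shapes_[name] = std::move(shape);
  }

  // Explicit sync step, loosely analogous to forcing a re-resolve.
  void SyncShapesToProto() { proto_.value_info = arg_shapes_; }

  // Assumed behavior: returns the stored proto as-is.
  ToyProto ToProto() const { return proto_; }

 private:
  ToyProto proto_;
  std::map<std::string, Shape> arg_shapes_;
};

// Returns true if a graph reloaded from the proto sees the shape set on "Y".
bool ReloadedGraphSeesShape(bool sync_before_serialize) {
  ToyGraph forward{ToyProto{}};
  forward.SetShape("Y", Shape{2, 3});
  if (sync_before_serialize) {
    forward.SyncShapesToProto();
  }
  ToyGraph gradient(forward.ToProto());
  return gradient.HasShape("Y");
}
```

In this sketch the reloaded graph only sees the shape of "Y" when the sync step runs before serialization, which mirrors the behavior reported here.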

To reproduce

I constructed a graph of two chained Identity nodes, X -> node1 -> Y -> node2 -> Z (image not preserved).

The ONNX file can be created with the following code:

  onnxruntime::Model original_model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                           {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
  onnxruntime::Graph& graph = original_model.MainGraph();
  TypeProto tensor_float;
  tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
  onnxruntime::NodeArg input_def("X", &tensor_float), inter_def("Y", &tensor_float), output_def("Z", &tensor_float);

  onnxruntime::Node* node = &graph.AddNode("node1", "Identity", "Identity operator", ArgMap{&input_def}, ArgMap{&inter_def});
  node->SetExecutionProviderType(kCpuExecutionProvider);
  onnxruntime::Node* node2 = &graph.AddNode("node2", "Identity", "Identity operator", ArgMap{&inter_def}, ArgMap{&output_def});
  node2->SetExecutionProviderType(kCpuExecutionProvider);
  ASSERT_STATUS_OK(Model::Save(original_model, "./test.onnx"));

The following test code reproduces the issue:

  constexpr const ORTCHAR_T* model_uri = "test.onnx";
  std::shared_ptr<Model> forward_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, forward_model, nullptr, *logger_));
  Graph& graph = forward_model->MainGraph();
  NodeArg* arg = graph.GetNodeArg("Y");
  ASSERT_TRUE(arg->Shape() == nullptr);
  onnx::TensorShapeProto new_shape;
  new_shape.add_dim()->set_dim_value(2);
  new_shape.add_dim()->set_dim_value(3);
  arg->SetShape(new_shape); // Set the shape; the change is not synced and is not visible in gradient_model.

  // The snippet below syncs the shape change.
  // If it is uncommented, the shape change is visible in the gradient graph and this test succeeds,
  // but the current ORT code does not do this.
  // graph.Set_is_loaded_from_model_file(false);
  // graph.SetGraphResolveNeeded();
  // Graph::ResolveOptions resolve_options;
  // ASSERT_STATUS_OK(graph.Resolve(resolve_options));

  std::shared_ptr<Model> gradient_model;
  ASSERT_STATUS_OK(Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_));
  Graph& gradient_graph = gradient_model->MainGraph();
  NodeArg* gradient_arg = gradient_graph.GetNodeArg("Y");
  ASSERT_TRUE(gradient_arg->Shape() != nullptr); // Fails: the shape change is not visible in the gradient graph.

Urgency

No response

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

ed550b5fe5aa41e182db84d2b2f2fb768121fd7a

PyTorch Version

2.2.0

Execution Provider

Default CPU

Execution Provider Library Version

No response

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.