In ORT training, we run graph transformations/optimizations on forward_model and then call 'Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_)' to obtain gradient_model.
If a transformation/optimization step changes the shape of a NodeArg, the change is not visible in gradient_model's graph.
This is because the shape change is stored in Graph::node_args_, which is not synced when the graph is serialized to a proto.
constexpr const ORTCHAR_T* model_uri = ORT_TSTR("test.onnx");  // ORT_TSTR keeps the literal valid where ORTCHAR_T is wchar_t
std::shared_ptr<Model> forward_model;
ASSERT_STATUS_OK(Model::Load(model_uri, forward_model, nullptr, *logger_));
Graph& graph = forward_model->MainGraph();
NodeArg* arg = graph.GetNodeArg("Y");
ASSERT_TRUE(arg->Shape() == nullptr);
onnx::TensorShapeProto new_shape;
new_shape.add_dim()->set_dim_value(2);
new_shape.add_dim()->set_dim_value(3);
arg->SetShape(new_shape);  // Sets the shape in memory only; the change is not synced and is not visible in gradient_model.
// The commented-out snippet below syncs the shape change.
// If it is uncommented, the shape change is visible in the gradient graph and this test succeeds,
// but the current ORT code does not do this.
// graph.Set_is_loaded_from_model_file(false);
// graph.SetGraphResolveNeeded();
// Graph::ResolveOptions resolve_options;
// ASSERT_STATUS_OK(graph.Resolve(resolve_options));
std::shared_ptr<Model> gradient_model;
ASSERT_STATUS_OK(Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_));
Graph& gradient_graph = gradient_model->MainGraph();
NodeArg* gradient_arg = gradient_graph.GetNodeArg("Y");
ASSERT_TRUE(gradient_arg->Shape() != nullptr);  // Fails: the shape change is not visible in the gradient graph.
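To make the suspected mechanism easier to discuss, here is a minimal toy sketch of it. It does not use or reproduce the onnxruntime implementation: ToyProto, ToyGraph, Sync, HasShape and ShapeSurvivesRoundTrip are hypothetical names invented for illustration, and the only assumption taken from the report is that shapes set in memory are not copied into the serialized proto unless the graph is explicitly re-synced.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-ins for illustration only; these are NOT onnxruntime types.
using Shape = std::vector<long long>;

// Serialized form: value name -> shape. A missing name means "shape unknown".
struct ToyProto {
  std::map<std::string, Shape> shapes;
};

class ToyGraph {
 public:
  // Loading copies the proto and builds the in-memory node args from it.
  explicit ToyGraph(const ToyProto& proto) : proto_(proto), node_args_(proto.shapes) {}

  // Mutates only the in-memory state, like the SetShape call in the test above.
  void SetShape(const std::string& name, const Shape& shape) { node_args_[name] = shape; }

  // Copies the in-memory shapes back into the stored proto
  // (the role the commented-out workaround plays in the test above).
  void Sync() { proto_.shapes = node_args_; }

  // Returns the stored proto without consulting node_args_.
  const ToyProto& ToProto() const { return proto_; }

  bool HasShape(const std::string& name) const { return node_args_.count(name) > 0; }

 private:
  ToyProto proto_;
  std::map<std::string, Shape> node_args_;
};

// Sets a shape on "Y", round-trips through the proto, and reports whether
// the reloaded graph sees the shape.
bool ShapeSurvivesRoundTrip(bool sync_before_serialize) {
  ToyProto proto;                      // "Y" starts without a shape
  ToyGraph forward(proto);
  forward.SetShape("Y", {2, 3});
  assert(forward.HasShape("Y"));       // the in-memory graph does see the change
  if (sync_before_serialize) forward.Sync();
  ToyGraph gradient(forward.ToProto());
  return gradient.HasShape("Y");
}
```

In this toy model, ShapeSurvivesRoundTrip(false) mirrors the failing test and ShapeSurvivesRoundTrip(true) mirrors the test with the workaround uncommented. Whether the real fix belongs in ToProto or in the training code that calls it is left open.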
To reproduce
I construct a graph as:
the onnx file can be created by code:
and test code to reproduce the issue:
Urgency
No response
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
ed550b5fe5aa41e182db84d2b2f2fb768121fd7a
PyTorch Version
2.2.0
Execution Provider
Default CPU
Execution Provider Library Version
No response