microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

ORT execution fails when a gradient builder is not registered for module-local functions #9375

Open adk9 opened 2 years ago

adk9 commented 2 years ago

I'm trying to run a simple test case with ORTModule that has one linear layer inside a module as shown below:

import torch
from onnxruntime.training.ortmodule import ORTModule

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(x)

x = torch.randn(2, 2)

# Case 1: This works
model = ORTModule(M())
y_pred = model(x)

# Case 2: This doesn't work
class ORTModuleExtension(ORTModule):
    def __init__(self, module, debug_options=None):
        super().__init__(module, debug_options)
        for training_mode in [False, True]:
            self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {'export_modules_as_functions': True}

model = ORTModuleExtension(M())
y_pred = model(x)

In case 1 above, the ONNX graph exported by PyTorch has the module completely inlined:

[Image: case 1 graph]

Whereas in case 2, it is wrapped inside a "flattened module", as shown below (the right image is the graph corresponding to the function body of _FlattenedModule):

[Image: case 2 main graph]
[Image: case 2 flattenedmodule]

The latter case should ideally work since modules/functions are supposed to be inlined during partitioning. But I see the error below in case 2:

File "onnxruntime/training/ortmodule/_graph_execution_manager.py", line 232, in _build_graph
self._graph_builder.build()
RuntimeError: onnxruntime/orttraining/orttraining/core/graph/gradient_builder_registry.cc:28 onnxruntime::training::GradientDef onnxruntime::training::GetGradientForOp(const onnxruntime::training::GradientGraphConfiguration&, onnxruntime::Graph*, const onnxruntime::Node*, const std::unordered_set<std::__cxx11::basic_string<char> >&, const std::unordered_set<std::__cxx11::basic_string<char> >&, const onnxruntime::logging::Logger&, std::unordered_set<std::__cxx11::basic_string<char> >&)
gradient_builder != nullptr was false. The gradient builder has not been registered: _FlattenedModule for node _FlattenedModule_4

System information
ORT version: master
Torch version: Branch onnx_ms_1

thiagocrepaldi commented 2 years ago

The exported model seems to be correct. The user's model is wrapped by _FlattenedModule so that the model inputs and outputs are flattened before exporting. Inside _FlattenedModule, the user model is correctly represented, as shown by the bottom right image.
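To illustrate what "flattened" means here, a minimal pure-Python sketch: nested inputs (tuples, lists, dicts) are reduced to a flat list of leaves plus a structure spec, so the exported graph only sees flat inputs and outputs, and the original structure can be rebuilt afterwards. The helpers `flatten` and `unflatten` are hypothetical and are not ORTModule's actual `_FlattenedModule` code; this sketch only handles plain lists, tuples and dicts with sortable keys.

```python
from typing import Any, List, Tuple


def flatten(obj: Any) -> Tuple[List[Any], Any]:
    """Hypothetical helper: return (leaves, spec) for a nested structure."""
    if isinstance(obj, (list, tuple)):
        container = tuple if isinstance(obj, tuple) else list
        leaves: List[Any] = []
        child_specs = []
        for item in obj:
            item_leaves, item_spec = flatten(item)
            leaves.extend(item_leaves)
            # Remember how many leaves each child contributed.
            child_specs.append((len(item_leaves), item_spec))
        return leaves, (container, child_specs)
    if isinstance(obj, dict):
        keys = sorted(obj)  # deterministic order for the flat list
        leaves, inner = flatten([obj[k] for k in keys])
        return leaves, (dict, keys, inner)
    return [obj], None  # a leaf, e.g. a tensor


def unflatten(leaves: List[Any], spec: Any) -> Any:
    """Hypothetical helper: rebuild the nested structure from (leaves, spec)."""
    if spec is None:
        return leaves[0]
    if spec[0] is dict:
        _, keys, inner = spec
        return dict(zip(keys, unflatten(leaves, inner)))
    container, child_specs = spec
    items, start = [], 0
    for count, child_spec in child_specs:
        items.append(unflatten(leaves[start:start + count], child_spec))
        start += count
    return container(items)


nested = ({"a": 1, "b": (2, 3)}, [4])
leaves, spec = flatten(nested)
print(leaves)                             # [1, 2, 3, 4]
print(unflatten(leaves, spec) == nested)  # True
```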

The error seems to be coming from the core, during gradient build.
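Why the gradient build fails can be sketched with a toy registry keyed by op type. This is a hypothetical illustration (the names `Node`, `GRADIENT_BUILDERS` and `get_gradient_for_op` and the registry contents are assumptions, not ORT's real code): a node that calls a model-local function carries the function's name as its op type, so a lookup keyed only by op type misses it.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Node:
    name: str
    op_type: str


# Hypothetical registry keyed only by op_type; the entries and the strings
# they return are illustrative, not ORT's real gradient builders.
GRADIENT_BUILDERS: Dict[str, Callable[[Node], List[str]]] = {
    "Gemm": lambda node: [f"{node.name}_grad_gemm"],
    "Relu": lambda node: [f"{node.name}_grad_relu"],
}


def get_gradient_for_op(node: Node) -> List[str]:
    """Look up a gradient builder by the node's op_type (hypothetical helper)."""
    builder = GRADIENT_BUILDERS.get(node.op_type)
    if builder is None:
        # A node calling a model-local function has the function name as its
        # op_type (e.g. "M"), which has no entry here, so the lookup fails even
        # if every op inside the function body has a registered builder.
        raise RuntimeError(
            "The gradient builder has not been registered: "
            f"{node.op_type} for node {node.name}"
        )
    return builder(node)


print(get_gradient_for_op(Node("Gemm_0", "Gemm")))  # ['Gemm_0_grad_gemm']
```

Calling `get_gradient_for_op(Node("M_1", "M"))` in this sketch raises an error of the same shape as the one reported in this issue.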

adk9 commented 2 years ago

Actually, this doesn't seem related to exporting all modules as functions. A similar error is seen when trying to export a single module as a function. Consider the simplified example below:

import torch
from onnxruntime.training.ortmodule import ORTModule, DebugOptions

torch.manual_seed(42)
torch.device("cuda")  # note: no effect as written; the device object is created and discarded

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(x)

class ORTModuleExtension(ORTModule):
    def __init__(self, module, debug_options=None):
        super().__init__(module, debug_options)
        for training_mode in [False, True]:
            self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {'export_modules_as_functions': {M}}

x = torch.randn(2, 2)
model = ORTModuleExtension(M(), DebugOptions(save_onnx=True, onnx_prefix='out_'))
for t in range(1):
    y_pred = model(x)
    print(t, y_pred)

In this case there's no "flattened module"; the graph is similar to the one above (on the right), with a single module function M. The error encountered is:

RuntimeError: onnxruntime/orttraining/orttraining/core/graph/gradient_builder_registry.cc:28
onnxruntime::training::GradientDef onnxruntime::training::GetGradientForOp(const onnxruntime::training::GradientGraphConfiguration&, onnxruntime::Graph*, const onnxruntime::Node*, const std::unordered_set<std::__cxx11::basic_string<char> >&, const std::unordered_set<std::__cxx11::basic_string<char> >&, const onnxruntime::logging::Logger&, std::unordered_set<std::__cxx11::basic_string<char> >&)
gradient_builder != nullptr was false. The gradient builder has not been registered: M for node M_1

M, in this case, is a module-local function. Why is a gradient builder for M required? The function body for M should simply be inlined and used for gradient building.

thiagocrepaldi commented 2 years ago

As per the ONNX spec (https://github.com/onnx/onnx/blob/master/docs/IR.md?plain=1#L32), we should probably inline the body of an unknown function:

Functionality-wise, an ONNX compatible framework or runtime may inline a function body to execute it if it does not have corresponding implementation of the function.
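The inlining the spec describes can be sketched in pure Python on a toy graph representation. Everything here (`Node`, `Function`, `inline_functions`, `OPS_WITH_GRADIENTS`) is a hypothetical illustration, not ORT's or the `onnx` package's API; it keys functions by name only (real ONNX functions are identified by domain and name) and ignores attributes and opset imports. The idea: substitute the call site's actual tensors for the function's formal parameters, give body-internal tensors a unique prefix, then run the gradient-builder check on the inlined ops instead of on the function call.

```python
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass
class Node:
    op_type: str
    inputs: List[str]
    outputs: List[str]
    name: str = ""


@dataclass
class Function:
    name: str
    inputs: List[str]   # formal input parameter names
    outputs: List[str]  # formal output names
    nodes: List[Node]   # function body


# Hypothetical set of op types that have gradient builders.
OPS_WITH_GRADIENTS: Set[str] = {"Gemm", "Relu", "Neg"}


def _resolve(tensor: str, rename: Dict[str, str], prefix: str) -> str:
    # Formal parameters map to the caller's tensors; body-internal tensors
    # get a call-site prefix so two calls of the same function don't collide.
    return rename.get(tensor, f"{prefix}/{tensor}")


def inline_functions(nodes: List[Node], functions: Dict[str, Function]) -> List[Node]:
    """Replace every call to a local function with its renamed body."""
    result: List[Node] = []
    for node in nodes:
        fn = functions.get(node.op_type)
        if fn is None:
            result.append(node)  # a regular operator: keep as-is
            continue
        rename = dict(zip(fn.inputs, node.inputs))
        rename.update(zip(fn.outputs, node.outputs))
        prefix = node.name or node.op_type
        body = [
            Node(
                inner.op_type,
                [_resolve(t, rename, prefix) for t in inner.inputs],
                [_resolve(t, rename, prefix) for t in inner.outputs],
                f"{prefix}/{inner.name}",
            )
            for inner in fn.nodes
        ]
        # A function body may itself call other local functions.
        result.extend(inline_functions(body, functions))
    return result


def missing_gradients(nodes: List[Node]) -> List[str]:
    return [n.name for n in nodes if n.op_type not in OPS_WITH_GRADIENTS]


# Toy version of the case above: the main graph calls function "M",
# whose (assumed) body is a single Gemm.
functions = {
    "M": Function("M", ["x", "w", "b"], ["y"],
                  [Node("Gemm", ["x", "w", "b"], ["y"], "Gemm_0")]),
}
graph = [Node("M", ["input", "fc1.weight", "fc1.bias"], ["output"], "M_1")]

print(missing_gradients(graph))                               # ['M_1']
print(missing_gradients(inline_functions(graph, functions)))  # []
```

Inlining before the gradient-builder lookup means only the ops inside the function body need registered builders, which is the behavior this thread argues for.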

stale[bot] commented 2 years ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.