microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training] Error building gradient graph for bert models for on-device training #22465

Open riccardopinosio opened 1 week ago

riccardopinosio commented 1 week ago

Describe the issue

Hello,

See also this discussion. I'm opening this as a separate issue because, judging from previous issues, training should work for BERT models.

I am trying to generate training artifacts for DistilBERT like so:

```python
from pathlib import Path

import onnx
import torch
from onnxruntime.training import artifacts
from transformers import AutoModel

modelName = "distilbert/distilbert-base-uncased"

model = AutoModel.from_pretrained(modelName)
example_input = (
    torch.randint(10, (1, 10)),                    # input_ids
    torch.ones(10, dtype=torch.long).view(1, 10)   # attention_mask
)
model_path = Path("./embedding_test")

# Export in training mode (keeps training-only behavior such as dropout)
torch.onnx.export(
    model,
    example_input,
    str(model_path),
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    input_names=["input_ids",
                 "attention_mask"],
    output_names=["output"])

onnx_model = onnx.load(str(model_path))

p = Path("./embedding_training")
p.mkdir(exist_ok=True, parents=True)

# Mark every initializer as trainable and put an MSE loss on "output"
artifacts.generate_artifacts(onnx_model,
                             frozen_params=[],
                             requires_grad=[initializer.name for initializer in onnx_model.graph.initializer],
                             loss=artifacts.LossType.MSELoss,
                             optimizer=artifacts.OptimType.AdamW,
                             loss_input_names=["output"],
                             artifact_directory=p)
```

The exported ONNX model works fine for inference, but artifact generation fails with:

```
RuntimeError                              Traceback (most recent call last)
File /home/rpinosio/repositories/knights/hugot/python/generate_embedding_training_model.py:1
----> 1 artifacts.generate_artifacts(onnx_model,
      2 frozen_params=[],
      3 requires_grad=[initializer.name for initializer in onnx_model.graph.initializer],
      4 loss=artifacts.LossType.MSELoss,
      5 optimizer=artifacts.OptimType.AdamW,
      6 loss_input_names=["output"],
      7 artifact_directory=p)

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/artifacts.py:193, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, prefix, ort_format, custom_op_library, additional_output_names, nominal_checkpoint, loss_input_names)
    186     custom_op_library_path = pathlib.Path(custom_op_library)
    188 with onnxblock.base(loaded_model, model_path), (
    189     onnxblock.custom_op_library(custom_op_library_path)
    190     if custom_op_library is not None
    191     else contextlib.nullcontext()
    192 ):
--> 193     _ = training_block(*[output.name for output in loaded_model.graph.output])
    194     training_model, eval_model = training_block.to_model_proto()
    195     model_params = training_block.parameters()

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/onnxblock.py:204, in TrainingBlock.__call__(self, *args, **kwargs)
    196 self._parameters = _training_graph_utils.get_model_parameters(model, self._requires_grad, self._frozen_params)
    198 # Build the gradient graph. The gradient graph building is composed of the following steps:
    199 #   - Move all model parameters to model inputs.
    200 #   - Run orttraining graph transformers on the model.
    201 #   - Add the gradient graph to the optimized model.
    202 # The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
    203 # The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
--> 204 self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
    205     model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
    206 )
    208 logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
    210 _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:130, in build_gradient_graph(model, requires_grad, frozen_params, output_names, custom_op_library)
    127 optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
    129 # Assumption is that the first graph output is the loss output
--> 130 gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)
    132 _reorder_outputs(gradient_model, output_names, requires_grad)
    134 return gradient_model, eval_model

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:84, in _gradient_model_for(model, requires_grad, loss_name, options)
     79 logging.debug(
     80     "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
     81 )
     83 builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
---> 84 builder.build()
     85 return onnx.load_from_string(builder.get_model())

RuntimeError: /onnxruntime_src/orttraining/orttraining/core/graph/gradient_builder_base.h:123 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t, bool) const i < node_->OutputDefs().size() was false.
```

It looks like gradient graph construction fails because a gradient builder requests an output index beyond the end of a node's OutputDefs.
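Since the failing check compares an output index against a forward node's output count, one rough way to start narrowing it down is to list how many outputs each op type declares in the exported graph. This is a sketch under assumptions: the helper name `output_counts_by_op` is hypothetical, and it only reads the exported file, whereas ORT runs its own graph transformers before building gradients, so the node that actually trips the assertion may look different (or not exist) in this form.

```python
from collections import defaultdict
from typing import Any, Dict, Set


def output_counts_by_op(model: Any) -> Dict[str, Set[int]]:
    """Map each op_type to the set of output counts its nodes declare.

    `model` is expected to be an onnx.ModelProto, or anything exposing
    `model.graph.node` where each node has `op_type` and `output`.
    Omitted optional outputs are either left off the end of the list or
    appear as empty names; both are excluded from the count.
    """
    counts: Dict[str, Set[int]] = defaultdict(set)
    for node in model.graph.node:  # top-level graph only, no subgraphs
        declared = sum(1 for name in node.output if name)
        counts[node.op_type].add(declared)
    return dict(counts)
```

Usage, assuming the file exported above: `for op, sizes in sorted(output_counts_by_op(onnx.load("./embedding_test")).items()): print(op, sorted(sizes))`. Op types whose nodes declare fewer outputs than usual are the ones worth comparing against their gradient builders.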

To reproduce

See the code provided above.

Urgency

It is blocking development of Go bindings for ONNX Runtime training, which we want to use in our product.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.19.2

PyTorch Version

2.4.1+cu121

Execution Provider

Default CPU

Execution Provider Library Version

No response

jkbeavers commented 1 day ago

This looks similar to the issue I hit and fixed in https://github.com/microsoft/onnxruntime/pull/22414. You can check whether it's the same issue by switching your loss to cross-entropy: if artifact generation then succeeds, it's the same bug, and MSE loss should work with a nightly build, a local build from master, or the upcoming 1.20 release.
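If upgrading is the route taken, a small guard can flag whether the installed build is at or past the release the fix is expected in. This is a sketch: the 1.20 threshold comes from the comment above rather than from release notes, and both function names are hypothetical.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '1.19.2' or '1.20.0.dev20241020'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def mse_fix_expected(version: str, first_fixed: Tuple[int, int] = (1, 20)) -> bool:
    """True if `version` is at or above the release assumed to contain the fix.

    Development builds are treated as the release they precede, so a
    nightly built before the fix was merged would be misreported.
    """
    return parse_major_minor(version) >= first_fixed
```

For example, `import onnxruntime; mse_fix_expected(onnxruntime.__version__)` returns False on the 1.19.2 install reported here. Because dev builds such as `1.20.0.devYYYYMMDD` count as 1.20, the cross-entropy check (`loss=artifacts.LossType.CrossEntropyLoss`) remains the more direct test on a nightly.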