microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training][api:C++][feature request] Support Model Forward Output and Backward Gradient Extraction in ONNX runtime training #16232

Open zjc664656505 opened 1 year ago

zjc664656505 commented 1 year ago

Describe the issue

Dear ONNX Community,

Currently, I'm working on a decentralized model-parallel project using ONNX. Given a model, we first shard it into n sub-modules, where n is the total number of devices. We then build p2p connections between the devices in the computing cluster.

In this setting, we want device i to send its sub-module's forward output to device i+1, and device i+1 to send its sub-module's gradient back to device i for backpropagation.
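To make the exchange above concrete, here is a minimal sketch in plain C++ with no ONNX Runtime calls. Everything in it is hypothetical and for illustration only: `ScaleStage` stands in for one sub-module (it just computes y = w * x), `SquaredErrorGrad` stands in for the loss on the last device, and `PipelineStep` simulates the p2p transfers by passing vectors between stages in one process. It shows the two per-shard operations the scheme needs: a forward call that returns the output, and a backward call that accepts the gradient received from the next device.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one model shard: y = w * x (element-wise).
// This is NOT an ONNX Runtime type; it only illustrates the data flow.
struct ScaleStage {
    double w;                     // single trainable parameter
    std::vector<double> saved_x;  // input cached in Forward for use in Backward
    double grad_w;                // accumulated dL/dw

    explicit ScaleStage(double weight) : w(weight), grad_w(0.0) {}

    // Forward pass: returns the activations that device i sends to device i+1.
    std::vector<double> Forward(const std::vector<double>& x) {
        saved_x = x;
        std::vector<double> y(x.size());
        for (std::size_t k = 0; k < x.size(); ++k) y[k] = w * x[k];
        return y;
    }

    // Backward pass: takes dL/dy received from device i+1, accumulates dL/dw,
    // and returns dL/dx, which is what gets sent back to device i-1.
    std::vector<double> Backward(const std::vector<double>& grad_y) {
        assert(grad_y.size() == saved_x.size());
        std::vector<double> grad_x(grad_y.size());
        for (std::size_t k = 0; k < grad_y.size(); ++k) {
            grad_w += grad_y[k] * saved_x[k];
            grad_x[k] = w * grad_y[k];
        }
        return grad_x;
    }
};

// Loss on the last device: L = 0.5 * sum((y - t)^2). Returns dL/dy = y - t.
std::vector<double> SquaredErrorGrad(const std::vector<double>& y,
                                     const std::vector<double>& t) {
    assert(y.size() == t.size());
    std::vector<double> g(y.size());
    for (std::size_t k = 0; k < y.size(); ++k) g[k] = y[k] - t[k];
    return g;
}

// One training step over all stages. Passing vectors between stages here
// simulates the p2p sends. Returns dL/dx for the first stage's input.
std::vector<double> PipelineStep(std::vector<ScaleStage>& stages,
                                 const std::vector<double>& x,
                                 const std::vector<double>& target) {
    std::vector<double> act = x;
    for (ScaleStage& s : stages) act = s.Forward(act);  // device i -> i+1

    std::vector<double> grad = SquaredErrorGrad(act, target);
    for (auto it = stages.rbegin(); it != stages.rend(); ++it) {
        grad = it->Backward(grad);                      // device i+1 -> i
    }
    return grad;
}
```

In a real deployment each `Forward` and `Backward` call would run on a different device, with the returned vector serialized and sent over the p2p link. The sketch only fixes the interface each shard would have to expose.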

From what we have seen so far, ONNX Runtime training is done through the TrainStep method in the C++ API, which combines the forward and backward passes into a single call.

With this method, we can get the sub-module's forward training loss. However, we cannot run the backward pass without knowing the gradient that the dependent sub-module passes back to the current one. For example, to run the backward pass of the sub-module on device i, we need the gradient from the sub-module on device i+1. We also cannot get the forward model output from the TrainStep method, even though the official ONNX Runtime C++ training API page shows:

image

Example output from TrainStep is shown in the image below.

image

As you can see, TrainStep returns a vector of size 1 that contains only the training loss.

Therefore, I would like to ask whether there is a way to extract the backward gradient and the forward output of a given model from the training session in the ONNX Runtime training C++ API. We are very new to ONNX and would appreciate some help. If there is an approach for doing this, could you point us to a reference?

Thanks!

@baijumeswani @ashari4 @justinchuby @pengwa

To reproduce

Please refer to the on-device-training train.cpp file.

Urgency

This is urgent.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15

ONNX Runtime API

C++

Execution Provider

Default CPU

bitnom commented 4 weeks ago

I found this while researching the same thing as 143. Just curious about the state/progress. This feature would be a big deal.