microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[On-device Training] Yolo custom loss #19464

Open Marouan-st opened 7 months ago

Marouan-st commented 7 months ago

Discussed in https://github.com/microsoft/onnxruntime/discussions/19390

Originally posted by **Marouan-st** February 2, 2024

Hello, I would like to implement a custom loss so that I can train a yolov4-tiny object-detection model on-device. Computing the loss requires some post-processing of the model output, such as computing the IoU between bounding boxes and summing several losses (class loss + confidence loss + IoU loss, using cross-entropy losses): see https://www.nature.com/articles/s41598-021-02225-y/figures/3

I don't see how to implement all these computations in the custom loss, especially how to feed the post-processed tensors to the different losses, since the onnxblock loss functions take String arguments (input names) as input.

I'm using a yolov4-tiny model compiled from [darknet](https://github.com/AlexeyAB/darknet?tab=readme-ov-file#how-to-train-to-detect-your-custom-objects) and converted to ONNX from a [tensorflow implementation](https://github.com/onnx/models/blob/main/validated/vision/object_detection_segmentation/yolov4/dependencies/Conversion.ipynb) of the model. The Torch implementation of this loss function (for the model I'm using) would look like this (inspired by this [yolov4 loss tensorflow implementation](https://github.com/hunglc007/tensorflow-yolov4-tflite/blob/master/core/yolov4.py#L320)):

```py
import torch
import torch.nn.functional as F


def compute_loss(pred, conv, label, bboxes, STRIDES=[16, 32], NUM_CLASS=1, IOU_LOSS_THRESH=0.5, i=0):
    conv_shape = conv.size()
    batch_size = conv_shape[0]
    output_size = conv_shape[1]
    input_size = float(STRIDES[i] * output_size)  # plain Python number, no .to() needed

    conv = torch.reshape(conv, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))
    conv_raw_conf = conv[:, :, :, :, 4:5]
    conv_raw_prob = conv[:, :, :, :, 5:]

    pred_xywh = pred[:, :, :, :, 0:4]
    pred_conf = pred[:, :, :, :, 4:5]

    label_xywh = label[:, :, :, :, 0:4]
    respond_bbox = label[:, :, :, :, 4:5]
    label_prob = label[:, :, :, :, 5:]

    # Trailing axis of size 1 so giou broadcasts with respond_bbox
    # (tf.expand_dims(..., axis=-1) in the TF reference).
    giou = torch.unsqueeze(bbox_giou(pred_xywh, label_xywh), -1)
    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    giou_loss = respond_bbox * bbox_loss_scale * (1 - giou)

    # IoU of every predicted box against every ground-truth box, max over the ground truths.
    iou = bbox_iou(pred_xywh[:, :, :, :, None, :], bboxes[:, None, None, None, :, :])
    max_iou = torch.unsqueeze(torch.max(iou, dim=-1).values, -1)
    respond_bgd = (1.0 - respond_bbox) * (max_iou < IOU_LOSS_THRESH).to(torch.float32)

    conf_focal = torch.pow(respond_bbox - pred_conf, 2)
    # Element-wise sigmoid cross-entropy (tf.nn.sigmoid_cross_entropy_with_logits in the TF reference).
    conf_ce = F.binary_cross_entropy_with_logits(conv_raw_conf, respond_bbox, reduction="none")
    conf_loss = conf_focal * (respond_bbox * conf_ce + respond_bgd * conf_ce)
    prob_loss = respond_bbox * F.binary_cross_entropy_with_logits(conv_raw_prob, label_prob, reduction="none")

    giou_loss = torch.mean(torch.sum(giou_loss, dim=(1, 2, 3, 4)))
    conf_loss = torch.mean(torch.sum(conf_loss, dim=(1, 2, 3, 4)))
    prob_loss = torch.mean(torch.sum(prob_loss, dim=(1, 2, 3, 4)))

    return giou_loss + conf_loss + prob_loss


def bbox_iou(bboxes1, bboxes2):
    """
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    # (x, y, w, h) -> (x1, y1, x2, y2)
    bboxes1_coor = torch.cat(
        [bboxes1[..., :2] - bboxes1[..., 2:] * 0.5, bboxes1[..., :2] + bboxes1[..., 2:] * 0.5],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [bboxes2[..., :2] - bboxes2[..., 2:] * 0.5, bboxes2[..., :2] + bboxes2[..., 2:] * 0.5],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    # torch.maximum needs two tensors; clamp applies the scalar lower bound.
    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = torch.div(inter_area, union_area)
    return iou


def bbox_giou(bboxes1, bboxes2):
    """
    Generalized IoU
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    bboxes1_coor = torch.cat(
        [bboxes1[..., :2] - bboxes1[..., 2:] * 0.5, bboxes1[..., :2] + bboxes1[..., 2:] * 0.5],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [bboxes2[..., :2] - bboxes2[..., 2:] * 0.5, bboxes2[..., :2] + bboxes2[..., 2:] * 0.5],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = bboxes1_area + bboxes2_area - inter_area
    iou = torch.div(inter_area, union_area)

    # Smallest axis-aligned box enclosing both boxes.
    enclose_left_up = torch.minimum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    enclose_right_down = torch.maximum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])
    enclose_section = enclose_right_down - enclose_left_up
    enclose_area = enclose_section[..., 0] * enclose_section[..., 1]

    giou = iou - torch.div(enclose_area - union_area, enclose_area)
    return giou
```

Any suggestions? Thank you

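Before wiring `bbox_iou` and `bbox_giou` into a training graph, it can help to check them against a scalar reference on a hand-computed case. The sketch below is plain Python and applies the same (cx, cy, w, h) IoU and GIoU formulas to a single pair of boxes; the helper names `xywh_to_corners` and `iou_and_giou` are hypothetical, not part of any library.

```py
def xywh_to_corners(box):
    """(cx, cy, w, h) -> (x1, y1, x2, y2), the conversion bbox_iou/bbox_giou use."""
    cx, cy, w, h = box
    return (cx - w * 0.5, cy - h * 0.5, cx + w * 0.5, cy + h * 0.5)


def iou_and_giou(box1, box2):
    """Return (IoU, GIoU) for two boxes given as (cx, cy, w, h)."""
    a = xywh_to_corners(box1)
    b = xywh_to_corners(box2)
    area1 = box1[2] * box1[3]
    area2 = box2[2] * box2[3]

    # Intersection; negative extents (no overlap) are clamped to zero.
    inter_w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
    inter_h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
    inter = inter_w * inter_h
    union = area1 + area2 - inter
    iou = inter / union

    # Smallest axis-aligned box enclosing both boxes.
    enclose_w = max(a[2], b[2]) - min(a[0], b[0])
    enclose_h = max(a[3], b[3]) - min(a[1], b[1])
    enclose = enclose_w * enclose_h
    giou = iou - (enclose - union) / enclose
    return iou, giou


# Boxes with corners (0, 0, 2, 2) and (1, 1, 3, 3): IoU = 1/7, GIoU = 1/7 - 2/9.
iou, giou = iou_and_giou((1.0, 1.0, 2.0, 2.0), (2.0, 2.0, 2.0, 2.0))
print(round(iou, 4), round(giou, 4))  # 0.1429 -0.0794
```

Feeding the same two boxes through the tensor functions (as shape-(4,) tensors) should give matching values.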
OAHLSTM commented 6 months ago

Hello, I'm working on a similar topic: retraining a fine-tuned YoloV8 on the device with the ONNX Runtime training API, and I'm struggling to define the loss functions as onnxblock blocks. @baijumeswani, any help here would be appreciated.

Thank you.

baijumeswani commented 6 months ago

Hi there. I provided some suggestions here: https://github.com/microsoft/onnxruntime/issues/19464.

The idea is that if the loss is difficult to express in onnxblock, you could try creating an ONNX model from PyTorch that has the loss embedded inside it.

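```py
import onnx
import torch
from onnxruntime.training import artifacts


class MyPTModelWithLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... define the model layers here

    def forward(self, inputs):
        p, q, r = compute_logits(inputs)  # the model's forward pass
        loss = loss1(p) + loss2(q) + loss3(r)  # the loss is computed inside the graph
        return loss  # the exported model's first output is the loss


pt_model = MyPTModelWithLoss()
torch.onnx.export(pt_model, ...)

onnx_model = onnx.load(<exported_onnx_model_path>)
# loss=None: the first graph output is used as the loss
artifacts.generate_artifacts(onnx_model, requires_grad=[...], frozen_params=[...], loss=None, optimizer=...)
```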

This might become more complex if you already have the ONNX model and do not have access to the PyTorch model to add the loss function to. In that case, we can try to support your scenario with onnxblock. So, if this is where you are, please share your loss function, and I'll try to make onnxblock support that scenario.

Marouan-st commented 5 months ago

Hi,

I do have access to the YOLOv8n torch model from ultralytics. (ultralytics doc)

I tried to include the loss computation into my model and export it to onnx as follows:

```py
from ultralytics import YOLO
import torch

# Load a model (with pretrained weights)
model = YOLO("yolov8n.pt")

class YOLOv8nWithLoss(torch.nn.Module):
    def __init__(self, yolov8_model):
        super(YOLOv8nWithLoss, self).__init__()
        self.model = yolov8_model

    def forward(self, batch, targets):
        outputs = self.model.model(batch)
        loss = self.model.model.loss(targets, outputs)
        return loss

model_with_loss = YOLOv8nWithLoss(model)
model_with_loss.model.train()

# Export the model to ONNX.
model_name = "yolov8n_with_loss_eval_mode"

# Use opset_version < 18, otherwise ReduceMin error
torch.onnx.export(model_with_loss, (torch.randn(1, 3, 640, 640), torch.Tensor([[0,1,0.85,0.45,0.57,0.98]])),
                  f"training_artifacts/{model_name}.onnx",
                  input_names=["images", "targets"], output_names=["loss"],
                  dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"},
                                "targets": {0: "batch"},
                                "loss": {0: "batch", 2: "anchors"}},
                  training=torch.onnx.TrainingMode.PRESERVE, opset_version=17)
```

Loss implementation is available here (cf class v8DetectionLoss)

The export goes well (onnx graph yolov8n_with_loss_train_mode.zip) but with the following warning:

```
/venv/object_detection/lib/python3.10/site-packages/torch/onnx/utils.py:1686: UserWarning: The exported ONNX model failed ONNX shape inference. The model will not be executable by the ONNX Runtime. If this is unintended and you believe there is a bug, please report an issue at https://github.com/pytorch/pytorch/issues. Error reported by strict ONNX shape inference: [ShapeInferenceError] (op_type:Concat, node name: /Concat_21): inputs has inconsistent type tensor(int32) (Triggered internally at ../torch/csrc/jit/serialization/export.cpp:1415.)
  _C._check_onnx_proto(proto)
```

Then, when I try to generate the artifacts I get the following error:

```
InferenceError                            Traceback (most recent call last)
Cell In[9], line 14
      6 frozen_params = [
      7    param.name
      8    for param in onnx_model.graph.initializer
      9    if param.name not in requires_grad
     10 ]
     13 # Generate the training artifacts.
---> 14 artifacts.generate_artifacts(
     15    onnx_model,
     16    requires_grad=requires_grad,
     17    frozen_params=frozen_params,
     18    loss=None,
     19    optimizer=artifacts.OptimType.AdamW,
     20    artifact_directory="training_artifacts"
     21 )

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:154, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, **extra_options)
    149     custom_op_library = pathlib.Path(custom_op_library)
    151 with onnxblock.base(model), onnxblock.custom_op_library(
    152     custom_op_library
    153 ) if custom_op_library is not None else contextlib.nullcontext():
--> 154     _ = training_block(*[output.name for output in model.graph.output])
    155     training_model, eval_model = training_block.to_model_proto()
    156     model_params = training_block.parameters()

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/onnxblock.py:188, in TrainingBlock.__call__(self, *args, **kwargs)
    184 self.base = accessor._GLOBAL_ACCESSOR.model
    186 logging.debug("Building training block %s", self.__class__.__name__)
--> 188 output = self.build(*args, **kwargs)
    190 model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)
    192 _graph_utils.register_graph_outputs(model, output)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:124, in generate_artifacts.<locals>._TrainingBlock.build(self, *inputs_to_loss)
    121     else:
    122         return (loss_output, *tuple(extra_options["additional_output_names"]))
--> 124 return self._loss(*inputs_to_loss)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/blocks.py:50, in Block.__call__(self, *args, **kwargs)
     46 logging.debug("Building block: %s", self.__class__.__name__)
     48 output = self.build(*args, **kwargs)
---> 50 onnx.checker.check_model(self.base, True)
     52 return output

File ~/venv/object_detection/lib/python3.10/site-packages/onnx/checker.py:148, in check_model(model, full_check, skip_opset_compatibility_check)
    144 if sys.getsizeof(protobuf_string) > MAXIMUM_PROTOBUF:
    145     raise ValueError(
    146         "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
    147     )
--> 148 C.check_model(protobuf_string, full_check, skip_opset_compatibility_check)

InferenceError: [ShapeInferenceError] (op_type:Concat, node name: /Concat_21): inputs has inconsistent type tensor(int32)
```

@baijumeswani Any idea of what could cause this error?

Thank you

baijumeswani commented 5 months ago

Looking at your model after running shape inference on it, the Concat node (/Concat_21) is trying to concatenate tensors of different types (int64 and int32), and this will fail with onnxruntime: all the tensors being concatenated need to have the same type. You can try adding a Cast node that casts the int32 tensor to int64 and see how that goes.

Marouan-st commented 5 months ago

I found the int32 tensor and changed its type to int64 and I don't get the ShapeInference error anymore, thanks.

Now I get another error when trying to generate the artifacts:

```
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 15
      7 frozen_params = [
      8    param.name
      9    for param in onnx_model.graph.initializer
     10    if param.name not in requires_grad
     11 ]
     14 # Generate the training artifacts.
---> 15 artifacts.generate_artifacts(
     16    onnx_model,
     17    requires_grad=requires_grad,
     18    frozen_params=frozen_params,
     19    loss=None,
     20    optimizer=artifacts.OptimType.AdamW,
     21    artifact_directory="training_artifacts"
     22 )

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/artifacts.py:154, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, **extra_options)
    149     custom_op_library = pathlib.Path(custom_op_library)
    151 with onnxblock.base(model), onnxblock.custom_op_library(
    152     custom_op_library
    153 ) if custom_op_library is not None else contextlib.nullcontext():
--> 154     _ = training_block(*[output.name for output in model.graph.output])
    155     training_model, eval_model = training_block.to_model_proto()
    156     model_params = training_block.parameters()

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/onnxblock.py:204, in TrainingBlock.__call__(self, *args, **kwargs)
    196 self._parameters = _training_graph_utils.get_model_parameters(model, self._requires_grad, self._frozen_params)
    198 # Build the gradient graph. The gradient graph building is composed of the following steps:
    199 #   - Move all model parameters to model inputs.
    200 #   - Run orttraining graph transformers on the model.
    201 #   - Add the gradient graph to the optimized model.
    202 # The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
    203 # The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
--> 204 self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
    205     model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
    206 )
    208 logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
    210 _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

File ~/venv/object_detection/lib/python3.10/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:127, in build_gradient_graph(model, requires_grad, frozen_params, output_names, custom_op_library)
    124 if custom_op_library is not None:
    125     options.register_custom_ops_library(os.fspath(custom_op_library))
--> 127 optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
    129 # Assumption is that the first graph output is the loss output
    130 gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)

RuntimeError: /local/home/user/tools/onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:1010 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&)::<lambda(const pybind11::bytes&, const std::unordered_set<std::__cxx11::basic_string<char> >&, onnxruntime::python::PySessionOptions*)> [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (/Reshape_52_output_0_log_prob).
```

Here's the corrected graph (yolov8n_with_loss_train_mode_cast.zip) and how I generate the artifacts:

```py
import onnx
from onnxruntime.training import artifacts

# Load the onnx model.
model_name = "yolov8n_with_loss_train_mode_cast"
onnx_model = onnx.load(f"{model_name}.onnx")

requires_grad = ["model.model.model.22.cv3.2.2.weight", "model.model.model.22.cv3.2.2.bias"]
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]

# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=None,  # the exported model already outputs the loss
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
)
```

baijumeswani commented 5 months ago

^ Seems like a bug. I will open a pull request to address this issue.

Marouan-st commented 5 months ago

> ^ Seems like a bug. I will add a pull-request to address this issue.

Hello @baijumeswani, any update regarding this bug?

baijumeswani commented 5 months ago

I addressed the issue you highlighted here: https://github.com/microsoft/onnxruntime/pull/20016

However, there is still another problem: the model has a ReduceMax node, and ORT training does not define a gradient kernel for ReduceMax yet, so building the gradient graph fails.

Marouan-st commented 5 months ago

Ok, thank you for your support. Do you know how I could replace these ReduceMax nodes with operations that are supported for training? Also, there are ReduceMin nodes in the graph; is there a gradient kernel for these nodes?

Marouan-st commented 5 months ago

Hello @baijumeswani, ReduceMax and ReduceMin operations are only used in the loss computation, so the gradient is not really required for these operations. I'm following your suggested approach based on creating an onnx model from pytorch that contains the loss embedded inside it. I have two questions:

1. The model I provided contains the forward graph plus the loss computation, so I'm wondering: is there any way to build the gradient graph only for the forward part of the model?

2. In that case, how are we supposed to pass the loss function to the loss argument of the generate_artifacts function?

Thank you