microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Training] qat #18534

Open xll426 opened 11 months ago

xll426 commented 11 months ago

Describe the issue

RuntimeError: /onnxruntime_src/orttraining/orttraining/core/optimizer/qdq_fusion.cc:25 int onnxruntime::{anonymous}::ReplaceOrCreateZeroPointInitializer(onnxruntime::Graph&, onnxruntime::Node&) zero_point_tensor_int != nullptr was false. Expected: zero point initializer with name input-0_zero_point to be present in the graph. Actual: not found.
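For context on what the fusion pass is looking for: in a QDQ graph, the zero point is the third input of a QuantizeLinear/DequantizeLinear node and is stored as a graph initializer. A pure-Python sketch of the asymmetric uint8 quantize/dequantize math (illustrative only, not ORT code) shows the role that scale and zero_point play:

```python
def quantize_linear(x, scale, zero_point):
    """Asymmetric uint8 quantization: q = clamp(round(x / scale) + zero_point, 0, 255)."""
    q = round(x / scale) + zero_point
    return max(0, min(255, q))

def dequantize_linear(q, scale, zero_point):
    """Inverse mapping: x is approximately (q - zero_point) * scale."""
    return (q - zero_point) * scale

scale, zero_point = 0.1, 128
q = quantize_linear(1.0, scale, zero_point)        # 138
x_hat = dequantize_linear(q, scale, zero_point)    # back to 1.0
print(q, x_hat)
```

If the zero_point initializer is absent from the graph, the QDQ fusion pass has no tensor to rewrite, which is exactly the condition the RuntimeError above reports.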

To reproduce

_ = mnist_with_loss(*[output.name for output in onnx_model.graph.output])

Urgency

No response

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

PyTorch Version

1.13.1

Execution Provider

CUDA

Execution Provider Library Version

cuda11.6

xadupre commented 11 months ago

Is it possible to know more about how you got this error? I assume you called the function create_training_artifacts. Was it a quantized model?

xll426 commented 11 months ago

When I run the qat.py script directly, it reports this error.

def create_training_artifacts(model_path, artifacts_dir, model_prefix):
    """Using onnxblock, this function creates the training artifacts for the model at the path provided.

    The artifacts created can be used to train the model using onnxruntime.training.api. The artifacts are:
    1. The training graph
    2. The eval graph
    3. The optimizer graph
    4. The checkpoint file
    """

class MNISTWithLoss(onnxblock.TrainingBlock):
    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, output_name):
        return self.loss(output_name)

mnist_with_loss = MNISTWithLoss()
print(model_path)
onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None

# Build the training and eval graphs
logging.info("Using onnxblock to create the training artifacts.")

# def traverse_nodes(graph):
#     for node in graph.node:
#         print(node.name)
#         if node.name=="input-0_zero_point":
#             print("Node Name:", node.name)
#             print("Op Type:", node.op_type)
#             print("Input(s):", [input for input in node.input])
#             print("Output(s):", [output for output in node.output])
#             print("Attributes:")
#             for attribute in node.attribute:
#                 print(f"  {attribute.name}: {attribute}")

# with onnxblock.onnx_model(onnx_model) as model_accessor:
with onnxblock.base(onnx_model):

    # main_graph = onnx_model.graph

    # Traverse all nodes
    # traverse_nodes(main_graph)

    _ = mnist_with_loss(*[output.name for output in onnx_model.graph.output])
    # eval_model = model_accessor.eval_model
    training_model, eval_model = mnist_with_loss.to_model_proto()

# Build the optimizer graph
optimizer = onnxblock.optim.AdamW()
# with onnxblock.onnx_model() as accessor:
with onnxblock.empty_base() as accessor:
    _ = optimizer(mnist_with_loss.parameters())
    # optimizer_model = accessor.model
    optimizer_model = optimizer.to_model_proto()

# Create the training artifacts
train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
logging.info(f"Saving the training model to {train_model_path}.")
onnx.save(training_model, train_model_path)
eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
logging.info(f"Saving the eval model to {eval_model_path}.")
onnx.save(eval_model, eval_model_path)
optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
onnx.save(optimizer_model, optimizer_model_path)
trainable_params, non_trainable_params = mnist_with_loss.parameters()
checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
logging.info(f"Saving the checkpoint to {checkpoint_path}.")
onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)

return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
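One thing worth noting about the repro script: the commented-out traverse_nodes helper scans graph.node, but the tensor the fusion pass complains about ("input-0_zero_point") is a graph initializer, not a node. A minimal, onnx-free sketch of the lookup (SimpleNamespace stand-ins for the protobuf objects, names illustrative):

```python
from types import SimpleNamespace

def find_initializer(graph, name):
    """Return the initializer tensor with the given name, or None.

    `graph` is any object with an `initializer` list of tensors that
    have a `.name` attribute (the shape of an onnx GraphProto).
    """
    for tensor in graph.initializer:
        if tensor.name == name:
            return tensor
    return None

# Minimal stand-ins for demonstration (not real onnx protobufs):
graph = SimpleNamespace(initializer=[SimpleNamespace(name="input-0_zero_point")])
print(find_initializer(graph, "input-0_zero_point") is not None)  # True
print(find_initializer(graph, "missing") is None)                 # True
```

With a real model, the same loop over onnx_model.graph.initializer would confirm whether the zero-point tensor the error names is actually present.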
xll426 commented 11 months ago

[screenshot of the error attached]

xll426 commented 11 months ago

It is a quantized model. We use this model for Quantization-Aware Training (QAT): https://github.com/microsoft/onnxruntime/tree/v1.16.2/orttraining/orttraining/test/python/qat_poc_example
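For readers unfamiliar with where the zero-point initializer comes from: in asymmetric uint8 quantization it is derived, together with the scale, from the calibrated float range of a tensor. A pure-Python sketch of that computation (illustrative only, not the ORT quantizer's implementation):

```python
def compute_scale_zero_point(rmin, rmax, qmin=0, qmax=255):
    """Map the float range [rmin, rmax] onto [qmin, qmax] (asymmetric uint8).

    The range is first extended to include zero, so that real zero is
    exactly representable; this is why every scale in a QDQ graph is
    accompanied by a zero_point initializer.
    """
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, max(qmin, min(qmax, zero_point))

scale, zp = compute_scale_zero_point(-1.0, 1.0)
print(scale, zp)  # scale = 2/255, zero_point = 128
```

A model that is missing these initializers (for example, if they were stripped or renamed by an earlier graph transformation) would trigger exactly the failure reported in this issue.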

baijumeswani commented 9 months ago

@xll426 QAT in ORT is currently in the experimental phase. It is known that the feature is not complete yet. I will find some time to fix the POC. Sorry about your experience.

baijumeswani commented 9 months ago

https://github.com/microsoft/onnxruntime/pull/19290 should fix this. Sorry for the late response and fix.