Open xll426 opened 11 months ago
Could you share more about how you ran into this error? I assume you called the function create_training_artifacts. Was it a quantized model?
When I run the qat.py script directly, it reports this error.
def create_training_artifacts(model_path, artifacts_dir, model_prefix):
    """Using onnxblock, create the training artifacts for the model at ``model_path``.

    The artifacts created can be used to train the model using
    onnxruntime.training.api. The artifacts are:
      1. The training graph
      2. The eval graph
      3. The optimizer graph
      4. The checkpoint file

    Args:
        model_path: Path to the base ONNX model to build artifacts from.
        artifacts_dir: Directory the artifact files are written into.
        model_prefix: Filename prefix for every saved artifact.

    Returns:
        Tuple of (train_model_path, eval_model_path, optimizer_model_path,
        checkpoint_path) — the paths of the four saved artifacts.
    """

    class MNISTWithLoss(onnxblock.TrainingBlock):
        """Training block that appends a cross-entropy loss to the model's output."""

        def __init__(self):
            super().__init__()
            self.loss = onnxblock.loss.CrossEntropyLoss()

        def build(self, output_name):
            return self.loss(output_name)

    mnist_with_loss = MNISTWithLoss()
    onnx_model = onnx.load(model_path)

    # Build the training and eval graphs.
    logging.info("Using onnxblock to create the training artifacts.")
    with onnxblock.base(onnx_model):
        _ = mnist_with_loss(*[output.name for output in onnx_model.graph.output])
    training_model, eval_model = mnist_with_loss.to_model_proto()

    # Build the optimizer graph over the model's trainable parameters.
    optimizer = onnxblock.optim.AdamW()
    with onnxblock.empty_base():
        _ = optimizer(mnist_with_loss.parameters())
    optimizer_model = optimizer.to_model_proto()

    # Save the training artifacts.
    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
    logging.info(f"Saving the training model to {train_model_path}.")
    # Fix: save the generated training graph, not the base model that was loaded.
    onnx.save(training_model, train_model_path)

    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
    logging.info(f"Saving the eval model to {eval_model_path}.")
    onnx.save(eval_model, eval_model_path)

    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
    onnx.save(optimizer_model, optimizer_model_path)

    trainable_params, non_trainable_params = mnist_with_loss.parameters()
    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)

    return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
Yes, it is a quantized model. We use this model for Quantization-Aware Training (QAT): https://github.com/microsoft/onnxruntime/tree/v1.16.2/orttraining/orttraining/test/python/qat_poc_example
@xll426 QAT in ORT is currently in experimental phase. It is known that the feature is not complete yet. I will find some time to complete to fix the POC. Sorry about your experience.
https://github.com/microsoft/onnxruntime/pull/19290 should fix this. Sorry for the late response and fix.
Describe the issue
RuntimeError: /onnxruntime_src/orttraining/orttraining/core/optimizer/qdq_fusion.cc:25 int onnxruntime::{anonymous}::ReplaceOrCreateZeroPointInitializer(onnxruntime::Graph&, onnxruntime::Node&) zero_point_tensor_int != nullptr was false. Expected: zero point initializer with name input-0_zero_point to be present in the graph. Actual: not found.
To reproduce
_ = mnist_with_loss(*[output.name for output in onnx_model.graph.output])
Urgency
No response
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.3
PyTorch Version
1.13.1
Execution Provider
CUDA
Execution Provider Library Version
cuda11.6