Closed chethanpk closed 3 years ago
Hi @chethanpk, ORT Training currently doesn't target vision/image models. We had a preliminary implementation for BatchNormGradient, but it's not yet fully tested on various models.
Since I needed this working, here's a hacky patch. It is completely untested since it's still WIP, but I'll update it if anything turns out to be significantly wrong. At least it no longer complains about missing edges or unregistered ops.
There are three main pieces that were missing:
diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.cc b/onnxruntime/core/providers/cpu/nn/batch_norm.cc
index 50d0bc21d..db50fe65e 100644
--- a/onnxruntime/core/providers/cpu/nn/batch_norm.cc
+++ b/onnxruntime/core/providers/cpu/nn/batch_norm.cc
@@ -29,11 +29,12 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
+// We alias the running mean to the mean so it stays preserved across multiple batches
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, float,
- KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, double,
- KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
+ KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h
index 46ca31053..d6800ca59 100644
--- a/onnxruntime/core/providers/cpu/nn/batch_norm.h
+++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h
@@ -35,7 +35,12 @@ class BatchNorm : public OpKernel {
is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
-
+ is_train_ = OpKernel::Node().OutputDefs().size() == 5;
+ if (is_train_) {
+ ORT_ENFORCE(is_spatial_ == true, "Non spatial convolution for training");
+ }
+ auto mt = op_kernel_info.GetAttr<float>("momentum", &momentum_);
+ ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage());
// For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
// From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
@@ -70,14 +75,61 @@ class BatchNorm : public OpKernel {
ConstEigenVectorArrayMap<T> scale_arr(scale->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> bias_arr(B->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
+ // The saved mean corresponds to the mean from this batch
+ auto* saved_mean = is_train_ ? p_op_kernel_context->Output(3, TensorShape({(int) C})) : nullptr;
+ auto* saved_var = is_train_ ? p_op_kernel_context->Output(4, TensorShape({(int) C})) : nullptr;
+
+ // The running mean corresponds to the mean from all the batches
+ // During inference this running mean is used as the mean for BN
+ auto* running_mean = is_train_ ? p_op_kernel_context->Output(1, TensorShape({(int) C})) : nullptr;
+ auto* running_var = is_train_ ? p_op_kernel_context->Output(2, TensorShape({(int) C})) : nullptr;
+
+ if (is_train_) {
+ EigenVectorArrayMap<T> saved_mean_arr(saved_mean->template MutableData<T>(), C);
+ EigenVectorArrayMap<T> saved_var_arr(saved_var->template MutableData<T>(), C);
+ saved_mean_arr.setZero();
+ saved_var_arr.setZero();
+
+ ConstEigenArrayMap<T> X_arr(X->template Data<T>(), sample_size, N * C);
+ for (size_t nc = 0; nc < N * C; ++nc) {
+ saved_mean_arr(nc % C) += X_arr.col(nc).sum();
+ }
+ saved_mean_arr /= N * sample_size;
+ for (size_t nc = 0; nc < N * C; ++nc) {
+ saved_var_arr(nc % C) += (X_arr.col(nc) - saved_mean_arr(nc % C)).matrix().squaredNorm();
+ }
+ saved_var_arr /= N * sample_size;
+
+ // Assume that running mean and variance are initialized properly in the model given to us
+ // Because we alias it, we have the past history here
+ EigenVectorArrayMap<T> running_mean_arr(
+ running_mean->template MutableData<T>(), C);
+ EigenVectorArrayMap<T> running_var_arr(
+ running_var->template MutableData<T>(), C);
+ running_mean_arr = running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_);
+ running_var_arr = running_var_arr * momentum_ + saved_var_arr * (1. - momentum_);
+ }
+
// Regardless of training or testing, we will apply the estimated mean
// and standard deviation to the input. For testing, they are
// specified directly by the input, and for training, they are computed
// by the op.
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(is_spatial_ ? C : sample_size_incl_all_channels);
- ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
- inv_std = (var_arr + epsilon_).sqrt().inverse();
- ConstEigenVectorArrayMap<T> mean_arr(mean->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
+
+ if (!is_train_) {
+ ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
+ inv_std = (var_arr + epsilon_).sqrt().inverse();
+ } else {
+ // Note that, to be consistent with cudnn, we will actually output saved inverse std
+ EigenVectorArrayMap<T> saved_inv_std(saved_var->template MutableData<T>(), C);
+ saved_inv_std = (saved_inv_std + epsilon_).inverse().sqrt();
+ inv_std = saved_inv_std;
+ }
+
+ // If we're training, do batch normalization based on computation from this batch
+ ConstEigenVectorArrayMap<T> mean_arr(
+ !is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
+
// We can fuse the output computation as follows:
// ((x - est_mean) * (inv_var) * scale + bias
// to
@@ -105,7 +157,8 @@ class BatchNorm : public OpKernel {
protected:
float epsilon_;
+ float momentum_;
const bool is_spatial_;
- //int64_t is_test_; ignored in this implementation since we're doing inferencing only.
+ int64_t is_train_;
};
} // namespace onnxruntime
diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h
index 0aaf8a986..db6d203a4 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.h
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h
@@ -29,6 +29,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
STOP_GRADIENT_EDGES = {
{"Not", {0}},
{"And", {0, 1}},
+ {"BatchNormalization", {3, 4}},
{"Or", {0, 1}},
{"Xor", {0, 1}},
{"Equal", {0, 1}},
diff --git a/orttraining/orttraining/training_ops/systolic/systolic_training_kernels.cc b/orttraining/orttraining/training_ops/systolic/systolic_training_kernels.cc
index 342a64ae4..6c2ab8d49 100644
--- a/orttraining/orttraining/training_ops/systolic/systolic_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/systolic/systolic_training_kernels.cc
@@ -13,6 +13,7 @@ namespace systolic {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, ConvGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, ConvGrad_nhwc);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, MaxPoolGrad_nhwc);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kMSDomain, 1, BatchNormalizationGrad);
#endif
Status RegisterSystolicTrainingKernels(KernelRegistry& kernel_registry) {
@@ -21,6 +22,7 @@ Status RegisterSystolicTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, ConvGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, ConvGrad_nhwc)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kOnnxDomain, 9, MaxPoolGrad_nhwc)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kSystolicExecutionProvider, kMSDomain, 1, BatchNormalizationGrad)>,
#endif
};
I cleaned up the patch and opened a PR. @SherlockNoMad could you please review it?
This issue should be resolved with PR #6946. Closing it, please feel free to reopen.
Describe the bug This bug occurred when I was trying to train MobileNetV2 on the cats and dogs dataset with the C++ MNIST training app repurposed for this. While building the graph with gradient nodes, specifically for the BatchNormGrad node, it fails to get the NodeArg for the running mean. This is the exact error: GetArgDefsFromGraphFailed to get NodeArg with name mobilenetv20_features_batchnorm0_running_mean_grad. When I dug deeper, the node had 2 of the 4 input args (gamma and beta) correctly transferred to the gradient node, and it failed at the running mean, which returned null because the arg was never created. Is this expected behavior? Is there currently support for a BatchNorm gradient training op in the framework?
Urgency None
System information
To Reproduce I have attached the file I used to create the app, which is based on the MNIST training app. The easiest way to reproduce is to swap that file into the MNIST solution, download the cats and dogs dataset and the attached MobileNet model, then build the app and run it.
Expected behavior When the app is run, it exits with the below mentioned error: [W:onnxruntime:Default, graph.cc:84 onnxruntime::MergeShapeInfo] Error merging shape info for output. 'loss' source:{} target:{1}. Falling back to lenient merge. Fail: Not satisfied: node_arg optimizer_graph_builder.cc:30 onnxruntime::training::GetArgDefsFromGraphFailed to get NodeArg with name mobilenetv20_features_batchnorm0_running_mean_grad
src_and_model.zip