Closed attila-dusnoki-htec closed 3 months ago
Rewrite reduce mean changes: https://github.com/ROCm/AMDMIGraphX/tree/rewrite_reduce_mean
WIP matcher for pow2-div, in case anyone wants to continue working on it:
diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp
index 92eb489e6..a1d28e19c 100644
--- a/src/simplify_algebra.cpp
+++ b/src/simplify_algebra.cpp
@@ -1577,6 +1577,45 @@ struct find_split_transpose
}
};
+// Rewrites x^2 / n into (x / sqrt(n))^2 for a constant n, so that the squared
+// intermediate is scaled down before it can leave the fp16 range.
+// NOTE(review): sqrt(n) is only meaningful for n >= 0 - confirm the constants
+// reaching this matcher are non-negative (e.g. an element count).
+struct find_pow_div
+{
+    // Matches div(pow(x, 2), n) where n is a constant.
+    auto matcher() const
+    {
+        return match::name("div")(
+                   match::arg(0)(match::name("pow")(match::arg(0)(match::any().bind("x")),
+                                                    match::arg(1)(match::has_value(2.0f)))
+                                     .bind("pow")),
+                   match::arg(1)(match::is_constant().bind("n")))
+            .bind("div");
+    }
+
+    // Builds pow(div(x, sqrt(n)), 2) and substitutes it for the matched div.
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto div = r.instructions["div"];
+        auto pow = r.instructions["pow"];
+        auto n   = r.instructions["n"];
+        auto x   = r.instructions["x"];
+
+        // Insert in front of div rather than pow: n is an operand of div only, so
+        // it precedes div but is not necessarily defined before pow.
+        auto n_sqrt = m.insert_instruction(div, make_op("sqrt"), n);
+        auto new_x  = m.insert_instruction(div, make_op("div"), {x, n_sqrt});
+        // Reuse the matched exponent operand for the new pow.
+        auto new_pow =
+            m.insert_instruction(div, pow->get_operator(), {new_x, pow->inputs().at(1)});
+
+        // Redirect the users of div to the rewritten expression; without this
+        // the inserted instructions are dead and the graph is left unchanged.
+        m.replace_instruction(div, new_pow);
+    }
+};
+
void simplify_algebra::apply(module& m) const
{
// Run simplifications multiple times
@@ -1600,6 +1639,7 @@ void simplify_algebra::apply(module& m) const
find_zero_ops{},
find_dot_add{},
find_conv_add{},
+ find_pow_div{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
The RMSNorm implementation can currently overflow with fp16. The problem is that pow2 of the input falls outside the fp16 range. Rewriting the equation can help reduce this. The idea here is to get to a solution that is generally applicable, not just to RMSNorm.
original form:
1/sqrt(mean(x^2)+eps)
rewritten:
1/sqrt(sum((x/sqrt(n))^2)+eps)
The `reduce_mean -> reduce_sum` and `x^2/n -> (x/sqrt(n))^2` parts need to be generic.