migraphx-benchmark / AMDMIGraphX

AMD's graph optimization engine.
https://rocmsoftwareplatform.github.io/AMDMIGraphX/doc/html/
MIT License

Add rewrite logic to improve fp16 accuracy with reduce and pow2 div #163

Closed: attila-dusnoki-htec closed this 3 months ago

attila-dusnoki-htec commented 5 months ago

The RMSNorm implementation can currently overflow in fp16. The problem is that squaring the input (pow 2) can produce values outside the fp16 range, whose largest finite value is 65504. Rewriting the equation can help avoid this. The goal is a solution that is generally applicable, not specific to RMSNorm.

original form: 1/sqrt(mean(x^2)+eps)

rewritten: 1/sqrt(sum((x/sqrt(n))^2)+eps)
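
For reference, here is why the two forms are equal, with n the number of reduced elements. The second equality is the reduce_mean -> reduce_sum step, and the third is the x^2/n -> (x/sqrt(n))^2 step. The second line shows why the rewrite helps with range. Every term, and therefore every partial sum, of the rewritten reduction is bounded by mean(x^2) itself. So, ignoring rounding, the rewritten reduction can exceed the fp16 range only if the mean itself does, whereas a single x_i^2 in the original form can exceed it on its own.

```latex
\begin{aligned}
\operatorname{mean}(x^2) &= \frac{1}{n}\sum_{i=1}^{n} x_i^2
  = \sum_{i=1}^{n} \frac{x_i^2}{n}
  = \sum_{i=1}^{n} \left(\frac{x_i}{\sqrt{n}}\right)^2, \\
0 &\le \sum_{i=1}^{k} \left(\frac{x_i}{\sqrt{n}}\right)^2 \le \operatorname{mean}(x^2)
  \quad \text{for every } 1 \le k \le n.
\end{aligned}
```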

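Both parts of the rewrite need to be generic: reduce_mean -> reduce_sum (moving the 1/n factor inside the sum) and x^2/n -> (x/sqrt(n))^2, where n is the number of reduced elements.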

attila-dusnoki-htec commented 5 months ago

Branch with the reduce_mean rewrite changes: https://github.com/ROCm/AMDMIGraphX/tree/rewrite_reduce_mean

attila-dusnoki-htec commented 4 months ago

WIP matcher for the pow2-div rewrite, in case anyone wants to continue it:

diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp
index 92eb489e6..a1d28e19c 100644
--- a/src/simplify_algebra.cpp
+++ b/src/simplify_algebra.cpp
@@ -1577,6 +1577,36 @@ struct find_split_transpose
     }
 };

+struct find_pow_div
+{
+    auto matcher() const
+    {
+        return match::name("div")(
+                   match::arg(0)(match::name("pow")(match::arg(0)(match::any().bind("x")),
+                                                    match::arg(1)(match::has_value(2.0f)))
+                                     .bind("pow")),
+                   match::arg(1)(match::is_constant().bind("n")))
+            .bind("div");
+    }
+
+    // Rewrite pow(x, 2) / n as pow(x / sqrt(n), 2) so that the squared value
+    // stays within the fp16 range. Assumes n is a positive constant.
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto div = r.instructions["div"];
+        auto pow = r.instructions["pow"];
+        auto n   = r.instructions["n"];
+        auto x   = r.instructions["x"];
+        // Insert before div rather than pow: n is an input of div, so it is only
+        // guaranteed to be defined before div, not before pow.
+        auto n_sqrt  = m.insert_instruction(div, make_op("sqrt"), n);
+        auto new_x   = m.insert_instruction(div, make_op("div"), {x, n_sqrt});
+        auto new_pow = m.insert_instruction(div, pow->get_operator(), {new_x, pow->inputs().at(1)});
+        // Without this, the new instructions are dead code and the graph is unchanged.
+        m.replace_instruction(div, new_pow);
+    }
+};
+
 void simplify_algebra::apply(module& m) const
 {
     // Run simplifications multiple times
@@ -1600,6 +1630,7 @@ void simplify_algebra::apply(module& m) const
                             find_zero_ops{},
                             find_dot_add{},
                             find_conv_add{},
+                            find_pow_div{},
                             find_div_const{},
                             find_sub_const{},
                             find_rsqrt{},