AnyDSL / thorin

The Higher-Order Intermediate Representation
https://anydsl.github.io
GNU Lesser General Public License v3.0

Naked recursion segfaults the compiler #163

Open Hugobros3 opened 6 days ago

Hugobros3 commented 6 days ago

The following snippet (written using https://github.com/AnyDSL/artic/pull/23):

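#[export]
fn main (a : i32) -> i32 {
    test(a)
}

// test and test2 only call each other: neither ever returns (!),
// and each body just forwards its parameter unchanged.
fn test(a : i32) -> ! {
    test2(a)
}

fn test2(a : i32) -> ! {
    test(a)
}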

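Crashes the compiler, because eta-reduction in the Importer is greedy. When rewriting test, it skips rebuilding test itself and only tries to rebuild the callee test2. Rewriting test2 then does the same in reverse and tries to rebuild test, and the two rewrites alternate endlessly until the stack overflows and the compiler segfaults.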
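To make the failure mode concrete, here is a minimal C++17 sketch of greedy eta-reduction over a toy IR. It is not thorin's code: ToyCont, GreedyImporter, old2new and the max_depth guard are hypothetical. A continuation whose body only forwards its own parameters is modelled by the index of its callee, and the depth guard stands in for the native stack, so the sketch reports the endless test, test2, test, ... recursion through the overflowed flag instead of crashing.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy IR, not thorin's API: `callee` set means the body only
// forwards the continuation's own parameters there (eta-reducible); an empty
// `callee` means the continuation must be rebuilt as-is (a "leaf" here).
struct ToyCont {
    std::string name;
    std::optional<std::size_t> callee;
};

struct GreedyImporter {
    const std::vector<ToyCont>& src;
    std::map<std::size_t, std::size_t> old2new;  // memo: src index -> result index
    std::size_t depth = 0;
    std::size_t max_depth = 1000;  // stand-in for the native stack limit
    bool overflowed = false;

    std::size_t rewrite(std::size_t old) {
        assert(old < src.size());
        auto it = old2new.find(old);
        if (it != old2new.end()) return it->second;
        if (!src[old].callee) {  // leaves keep their own index in this toy
            old2new[old] = old;
            return old;
        }
        if (++depth > max_depth) {  // the real compiler overflows its stack instead
            overflowed = true;
            return old;
        }
        // Greedy: recurse into the callee before recording anything for `old`,
        // so test -> test2 -> test -> ... never hits the memo table.
        std::size_t result = rewrite(*src[old].callee);
        --depth;
        if (!overflowed) old2new[old] = result;
        return result;
    }
};
```

With an acyclic forwarding chain such as a, b, leaf, the same greedy rewrite terminates and maps a straight to leaf; it only diverges when the chain of forwarding continuations closes on itself, as test and test2 do.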

Hugobros3 commented 3 days ago

This patch should fix the issue:

diff --git a/src/thorin/transform/importer.cpp b/src/thorin/transform/importer.cpp
index 2bebb9d07..6c97a0cb9 100644
--- a/src/thorin/transform/importer.cpp
+++ b/src/thorin/transform/importer.cpp
@@ -65,10 +65,15 @@ const Def* Importer::rewrite(const Def* const odef) {
                     goto rebuild;

                 if (body->args() == cont->params_as_defs()) {
-                    src().VLOG("simplify: continuation {} calls a free def: {}", cont->unique_name(), body->callee());
                     // We completely replace the original continuation
                     // If we don't do so, then we miss some simplifications
-                    return instantiate(body->callee());
+                    src().VLOG("simplify: continuation {} calls a free def: {}", cont->unique_name(), body->callee());
+                    // However, we need a safety in the case of 'naked' recursion:
+                    auto safety = cont->stub(*this, instantiate(cont->type())->as<FnType>());
+                    insert(odef, safety);
+                    auto eta_reduced = instantiate(body->callee());
+                    safety->jump(eta_reduced, safety->params_as_defs());
+                    return eta_reduced;
                 } else {
                     // build the permutation of the arguments
                     Array<size_t> perm(body->num_args());
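The idea behind the patch can be sketched on the same kind of toy IR. Again this is hypothetical: ToyCont, SafeImporter, dst and old2new are not thorin's API, and parameters are omitted. Before eta-reducing a continuation, the sketch registers a stub as its translation (the counterpart of insert(odef, safety)), so a recursive visit of the same continuation finds the stub in the memo table instead of recursing again; once the callee has been rewritten, the stub is made to jump to it. The sketch also assumes, as return eta_reduced suggests, that the caller then records the returned value for the original continuation.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy IR, not thorin's API: `callee` set means the body only
// forwards the continuation's own parameters there (eta-reducible).
struct ToyCont {
    std::string name;
    std::optional<std::size_t> callee;
};

struct SafeImporter {
    const std::vector<ToyCont>& src;
    std::vector<ToyCont> dst;                    // rebuilt continuations
    std::map<std::size_t, std::size_t> old2new;  // memo: src index -> dst index

    std::size_t instantiate(std::size_t old) {
        auto it = old2new.find(old);
        if (it != old2new.end()) return it->second;
        std::size_t result = rewrite(old);
        old2new[old] = result;  // assumption: the caller records the returned value
        return result;
    }

    std::size_t rewrite(std::size_t old) {
        assert(old < src.size());
        const ToyCont& c = src[old];
        if (!c.callee) {  // not eta-reducible: rebuild it unchanged
            dst.push_back({c.name, std::nullopt});
            return dst.size() - 1;
        }
        // Counterpart of cont->stub(...) + insert(odef, safety): register the
        // stub first, so a cycle that comes back to `old` stops at the memo table.
        std::size_t safety = dst.size();
        dst.push_back({c.name + ".safety", std::nullopt});
        old2new[old] = safety;
        std::size_t eta_reduced = instantiate(*c.callee);
        dst[safety].callee = eta_reduced;  // counterpart of safety->jump(eta_reduced, ...)
        return eta_reduced;
    }
};
```

In the sketch, importing test returns the stub test.safety, whose body jumps to itself: the mutual recursion collapses into a single self-looping continuation that still never returns. An acyclic chain a, b, leaf is still replaced entirely by leaf, so the complete replacement the original comment wants to keep is preserved there.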