AnyDSL / thorin

The Higher-Order Intermediate Representation
https://anydsl.github.io
GNU Lesser General Public License v3.0

[New PE] Divergent partial evaluation #82

Closed: madmann91 closed this issue 6 years ago

madmann91 commented 6 years ago

Compile with -emit-thorin -Othorin.

struct SmallStack {
    write: fn (int, float) -> (),
    read:  fn (int) -> float
}

fn @(true) make_small_stack(n: int) -> SmallStack {
    fn @(?begin & ?end) make_small_stack_helper(begin: int, end: int) -> SmallStack {
        if begin == end {
            SmallStack {
                write: @(true) |_, _| (),
                read:  @(true) |_| 0.0f
            }
        } else if begin + 1 == end {
            let mut val : float;
            SmallStack {
                write: @(true) |i, v| val = v,
                read:  @(true) |i| val
            }
        } else {
            let m = (begin + end) / 2;
            let left  = make_small_stack_helper(begin, m);
            let right = make_small_stack_helper(m, end);
            SmallStack {
                write: @(true) |i, v| if i < m { left.write(i, v) } else { right.write(i, v) },
                read:  @(true) |i|    if i < m { left.read(i)     } else { right.read(i)     }
            }
        }
    }

    make_small_stack_helper(0, n)
}

fn @(?a & ?b) range_step(a: int, b: int, @(true) c: int, @(true) body: fn (int) -> ()) -> () {
    if a < b {
        body(a);
        range_step(a + c, b, c, body, return)
    }
}

fn @(true) range(a: int, b: int, body: fn (int) -> ()) -> () {
    range_step(a, b, 1, body);
}

fn main(res: &mut [float]) -> () {
    for n in range(0, 8) {
        let tmp = make_small_stack(n);
        for i in range(0, n) {
            tmp.write(i, i as float * 42.0f);
        }
        let mut sum = 0.0f;
        for i in range(0, n) {
            let v = tmp.read(i);
            sum += v as float;
        }
        res(n) = sum;
    }
}
madmann91 commented 6 years ago

As a side note, this pattern is currently the only way to get compile-time-sized arrays in Impala. It might be a bit contrived, but it generates pretty efficient code: every element of the array is mapped to its own alloca, and both our mem2reg and LLVM's mem2reg can promote these allocas to registers.
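For comparison, here is a rough C++ analogue of the SmallStack pattern. It is a sketch assuming C++17, and SmallStack and sum_of_writes are illustrative names, not part of Impala or thorin. Class templates play the role of the @(?begin & ?end) specialization, so every leaf of the bisection owns exactly one float, just like the one-alloca-per-element layout described above.

```cpp
#include <cassert>

// Range [Begin, End) of size >= 2: split at the midpoint, exactly like
// make_small_stack_helper does, and dispatch on the index.
template <int Begin, int End, int Size = End - Begin>
struct SmallStack {
    static constexpr int Mid = (Begin + End) / 2;
    SmallStack<Begin, Mid> left;
    SmallStack<Mid, End> right;
    void write(int i, float v) {
        if (i < Mid) left.write(i, v); else right.write(i, v);
    }
    float read(int i) const { return i < Mid ? left.read(i) : right.read(i); }
};

// Empty range: writes are ignored and reads yield 0.0f.
template <int Begin, int End>
struct SmallStack<Begin, End, 0> {
    void write(int, float) {}
    float read(int) const { return 0.0f; }
};

// Single slot: one scalar per element.
template <int Begin, int End>
struct SmallStack<Begin, End, 1> {
    float val = 0.0f;
    void write(int, float v) { val = v; }
    float read(int) const { return val; }
};

// Mirrors the body of the outer loop in main(): write i * 42 into slot i,
// then sum all slots.
template <int N>
float sum_of_writes() {
    SmallStack<0, N> s;
    for (int i = 0; i < N; ++i) s.write(i, static_cast<float>(i) * 42.0f);
    float sum = 0.0f;
    for (int i = 0; i < N; ++i) sum += s.read(i);
    return sum;
}
```

Whether each val actually ends up in a register depends on the C++ optimizer inlining the index dispatch, much as the Impala version depends on PE resolving it.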

madmann91 commented 6 years ago

Also, replacing the outer for loop in main() with let n = <some constant> makes the code compile.

leissa commented 6 years ago

This is actually a limitation of the current approach. We'll need the annotation at the call site, too:

fn @(?a & ?b) range_step(a: int, b: int, @(true) c: int, @(true) body: fn (int) -> ()) -> () {
    if a < b {
        @body(a);
        range_step(a + c, b, c, body, return)
    }
}

Otherwise, the constant isn't propagated to the inner loop.

madmann91 commented 6 years ago

But why does it work with a constant n, then?

leissa commented 6 years ago

Because then the inner loop knows the constant.

madmann91 commented 6 years ago

Yes, but the other two for loops also require inlining tmp.write() or tmp.read(); otherwise, the value that is written into sum is not known. Yet this seems to work without any call-site annotation. Is it just luck?

leissa commented 6 years ago

But those function expressions are annotated with @(true), no?

leissa commented 6 years ago

By the way, the value of sum isn't actually known after PE; it only becomes known once mem2reg has run. This is another limitation, but it applies to current master just as well. To deal with it, we would either need to run mem2reg on the fly or perform loads and stores during PE.

See https://github.com/AnyDSL/thorin/projects/1#card-2146514
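To illustrate the second option (performing loads and stores during PE), here is a minimal sketch over a hypothetical straight-line toy IR; Op and fold_sum are invented for this example and are not thorin's API. It forwards the last statically known stored value to later loads and deliberately ignores aliasing (for example through lea), which a real implementation would have to handle.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Toy straight-line "IR" (not thorin's): either store a value into a named
// slot, or load a slot and add it to a running sum.
struct Op {
    enum Kind { Store, LoadAdd } kind;
    std::string slot;
    std::optional<float> value;  // Store only: the value, if statically known
};

// Fold loads/stores at PE time: remember the last known value stored to each
// slot and forward it to later loads. Returns the sum if it is fully known,
// std::nullopt as soon as any loaded value is unknown.
std::optional<float> fold_sum(const std::vector<Op>& ops) {
    std::unordered_map<std::string, std::optional<float>> known;
    std::optional<float> sum = 0.0f;
    for (const Op& op : ops) {
        if (op.kind == Op::Store) {
            known[op.slot] = op.value;  // a later store overwrites an earlier one
        } else {
            auto it = known.find(op.slot);
            if (!sum || it == known.end() || !it->second)
                sum = std::nullopt;  // unknown load: the sum stays residual
            else
                *sum += *it->second;
        }
    }
    return sum;
}
```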

madmann91 commented 6 years ago

> But those function expressions are annotated with @(true), no?

Well, that still doesn't explain it: if you replace the outer for loop with

range(0, 8, @(true) |n| {
    let tmp = make_small_stack(n);
    /* ... */
});

the example still does not work.

leissa commented 6 years ago

This is yet another problem, which I'm currently working on :)

madmann91 commented 6 years ago

This is now fixed with your recent changes (I just tested it)!

madmann91 commented 6 years ago

It only worked with range(0, 2, @ |n| /*...*/), but not with anything bigger than 2, e.g. range(0, 3, @ |n| /*...*/). My bad.

leissa commented 6 years ago
struct SmallStack {
    write: fn (int, float) -> (),
    read:  fn (int) -> float
}

fn @(true) make_small_stack(n: int) -> SmallStack {
    fn @(?begin & ?end) make_small_stack_helper(begin: int, end: int) -> SmallStack {
        if begin == end {
            SmallStack {
                write: @(true) |_, _| (),
                read:  @(true) |_| 0.0f
            }
        } else if begin + 1 == end {
            let mut val : float;
            SmallStack {
                write: @(true) |i, v| val = v,
                read:  @(true) |i| val
            }
        } else {
            let m = (begin + end) / 2;
            let left  = make_small_stack_helper(begin, m);
            let right = make_small_stack_helper(m, end);
            SmallStack {
                write: @(true) |i, v| if i < m { left.write(i, v) } else { right.write(i, v) },
                read:  @(true) |i|    if i < m { left.read(i)     } else { right.read(i)     }
            }
        }
    }

    make_small_stack_helper(0, n)
}

fn @(?a & ?b) range_step(a: int, b: int, @(true) c: int, @(true) body: fn (int) -> ()) -> () {
    if a < b {
        body(a);
        range_step(a + c, b, c, body, return)
    }
}

fn @(true) range(a: int, b: int, body: fn (int) -> ()) -> () {
    range_step(a, b, 1, body);
}

fn main(res: &mut [float]) -> () {
    range(0, 8, @(true) |n| {
        let tmp = make_small_stack(n);
        range(0, n, @|i| {
            tmp.write(i, i as float * 42.0f);
        });
        let mut sum = 0.0f;
        range(0, n, @|i| {
            let v = tmp.read(i);
            sum += v as float;
        });
        res(n) = sum;
    });
}

With my latest change, this works now. I'll keep this issue open until I have implemented the @call(...) feature.

madmann91 commented 6 years ago

No, this still does not work. The expected output should consist of just loads, stores and adds, with no comparisons. I get this:

module 'new_pe'

main_18038(mem mem_18039, [pf32]* res_18040, fn(mem, ()) return_18041) extern 
    pf32* n_18075 = lea res_18040, qs32 0
    pf32* n_18082 = lea res_18040, qs32 3
    pf32* n_18078 = lea res_18040, qs32 1
    pf32* n_18080 = lea res_18040, qs32 2
    (mem, frame) _18052 = enter mem_18039
    frame _18054 = extract _18052, qu32 1
    mem _18073 = extract _18052, qu32 0
    pf32* val_18215 = slot _18054
    pf32* val_18070 = slot _18054
    pf32* val_18151 = slot _18054
    pf32* val_18161 = slot _18054
    pf32* val_18055 = slot _18054
    pf32* val_18061 = slot _18054
    pf32* val_18067 = slot _18054
    pf32* val_18102 = slot _18054
    pf32* val_18158 = slot _18054
    pf32* val_18113 = slot _18054
    pf32* val_18241 = slot _18054
    pf32* val_18244 = slot _18054
    pf32* val_18222 = slot _18054
    pf32* val_18232 = slot _18054
    pf32* val_18116 = slot _18054
    pf32* val_18175 = slot _18054
    pf32* val_18225 = slot _18054
    pf32* val_18166 = slot _18054
    pf32* val_18107 = slot _18054
    pf32* val_18172 = slot _18054
    pf32* val_18097 = slot _18054
    pf32* val_18235 = slot _18054
    mem _18077 = store _18073, n_18075, pf32 0
    mem _18079 = store _18077, n_18078, pf32 0
    mem _18081 = store _18079, n_18080, pf32 42
    mem _18084 = store _18081, n_18082, pf32 126
    lambda_18042(_18084, qs32 0, lambda_18085)

    lambda_18085(mem lambda_18086)
        lambda_18042(lambda_18086, qs32 1, lambda_18087)

    lambda_18087(mem lambda_18088)
        pf32* n_18129 = lea res_18040, qs32 4
        mem _18119 = store lambda_18088, val_18067, pf32 84
        mem _18120 = store _18119, val_18070, pf32 126
        (mem, pf32) _18121 = load _18120, val_18055
        pf32 _18130 = extract _18121, qu32 1
        mem _18122 = extract _18121, qu32 0
        pf32 _18131 = add pf32 0, _18130
        (mem, pf32) _18123 = load _18122, val_18061
        pf32 _18132 = extract _18123, qu32 1
        mem _18124 = extract _18123, qu32 0
        pf32 _18133 = add _18131, _18132
        (mem, pf32) _18125 = load _18124, val_18067
        mem _18126 = extract _18125, qu32 0
        pf32 _18134 = extract _18125, qu32 1
        (mem, pf32) _18127 = load _18126, val_18070
        pf32 _18135 = add _18133, _18134
        mem _18128 = extract _18127, qu32 0
        pf32 _18136 = extract _18127, qu32 1
        pf32 _18137 = add _18135, _18136
        mem _18138 = store _18128, n_18129, _18137
        lambda_18089(_18138, qs32 0, lambda_18139)

    lambda_18139(mem lambda_18140)
        lambda_18089(lambda_18140, qs32 1, lambda_18141)

    lambda_18141(mem lambda_18142)
        mem _18177 = store lambda_18142, val_18107, pf32 84
        pf32* n_18191 = lea res_18040, qs32 5
        mem _18178 = store _18177, val_18113, pf32 126
        mem _18180 = store _18178, val_18116, pf32 168
        (mem, pf32) _18181 = load _18180, val_18097
        pf32 _18192 = extract _18181, qu32 1
        mem _18182 = extract _18181, qu32 0
        pf32 _18193 = add pf32 0, _18192
        (mem, pf32) _18183 = load _18182, val_18102
        pf32 _18194 = extract _18183, qu32 1
        mem _18184 = extract _18183, qu32 0
        pf32 _18195 = add _18193, _18194
        (mem, pf32) _18185 = load _18184, val_18107
        pf32 _18196 = extract _18185, qu32 1
        mem _18186 = extract _18185, qu32 0
        pf32 _18197 = add _18195, _18196
        (mem, pf32) _18187 = load _18186, val_18113
        mem _18188 = extract _18187, qu32 0
        pf32 _18198 = extract _18187, qu32 1
        (mem, pf32) _18189 = load _18188, val_18116
        pf32 _18199 = add _18197, _18198
        mem _18190 = extract _18189, qu32 0
        pf32 _18200 = extract _18189, qu32 1
        pf32 _18201 = add _18199, _18200
        mem _18202 = store _18190, n_18191, _18201
        lambda_18143(_18202, qs32 0, lambda_18203)

    lambda_18203(mem lambda_18204)
        lambda_18143(lambda_18204, qs32 1, lambda_18205)

    lambda_18205(mem lambda_18206)
        pf32* n_18263 = lea res_18040, qs32 6
        mem _18246 = store lambda_18206, val_18161, pf32 84
        mem _18247 = store _18246, val_18166, pf32 126
        mem _18248 = store _18247, val_18172, pf32 168
        mem _18250 = store _18248, val_18175, pf32 210
        (mem, pf32) _18251 = load _18250, val_18151
        pf32 _18264 = extract _18251, qu32 1
        mem _18252 = extract _18251, qu32 0
        pf32 _18265 = add pf32 0, _18264
        (mem, pf32) _18253 = load _18252, val_18158
        pf32 _18266 = extract _18253, qu32 1
        mem _18254 = extract _18253, qu32 0
        pf32 _18267 = add _18265, _18266
        (mem, pf32) _18255 = load _18254, val_18161
        pf32 _18268 = extract _18255, qu32 1
        mem _18256 = extract _18255, qu32 0
        pf32 _18269 = add _18267, _18268
        (mem, pf32) _18257 = load _18256, val_18166
        pf32 _18270 = extract _18257, qu32 1
        mem _18258 = extract _18257, qu32 0
        pf32 _18271 = add _18269, _18270
        (mem, pf32) _18259 = load _18258, val_18172
        mem _18260 = extract _18259, qu32 0
        pf32 _18272 = extract _18259, qu32 1
        (mem, pf32) _18261 = load _18260, val_18175
        pf32 _18273 = add _18271, _18272
        mem _18262 = extract _18261, qu32 0
        pf32 _18274 = extract _18261, qu32 1
        pf32 _18275 = add _18273, _18274
        mem _18276 = store _18262, n_18263, _18275
        lambda_18207(_18276, qs32 0, lambda_18277)

    lambda_18277(mem lambda_18278)
        lambda_18207(lambda_18278, qs32 1, lambda_18279)

    lambda_18279(mem lambda_18280)
        mem _18281 = store lambda_18280, val_18225, pf32 84
        mem _18282 = store _18281, val_18232, pf32 126
        pf32* n_18302 = lea res_18040, qs32 7
        mem _18283 = store _18282, val_18235, pf32 168
        mem _18284 = store _18283, val_18241, pf32 210
        mem _18286 = store _18284, val_18244, pf32 252
        (mem, pf32) _18287 = load _18286, val_18215
        pf32 _18303 = extract _18287, qu32 1
        mem _18288 = extract _18287, qu32 0
        pf32 _18304 = add pf32 0, _18303
        (mem, pf32) _18289 = load _18288, val_18222
        pf32 _18305 = extract _18289, qu32 1
        mem _18290 = extract _18289, qu32 0
        pf32 _18306 = add _18304, _18305
        (mem, pf32) _18291 = load _18290, val_18225
        pf32 _18307 = extract _18291, qu32 1
        mem _18292 = extract _18291, qu32 0
        pf32 _18308 = add _18306, _18307
        (mem, pf32) _18293 = load _18292, val_18232
        pf32 _18309 = extract _18293, qu32 1
        mem _18294 = extract _18293, qu32 0
        pf32 _18310 = add _18308, _18309
        (mem, pf32) _18295 = load _18294, val_18235
        pf32 _18311 = extract _18295, qu32 1
        mem _18296 = extract _18295, qu32 0
        pf32 _18312 = add _18310, _18311
        (mem, pf32) _18297 = load _18296, val_18241
        mem _18298 = extract _18297, qu32 0
        pf32 _18313 = extract _18297, qu32 1
        (mem, pf32) _18299 = load _18298, val_18244
        pf32 _18314 = add _18312, _18313
        mem _18300 = extract _18299, qu32 0
        pf32 _18315 = extract _18299, qu32 1
        pf32 _18316 = add _18314, _18315
        mem _18317 = store _18300, n_18302, _18316
        return_18041(_18317, tuple ())

    lambda_18207(mem lambda_18208, qs32 lambda_18209, fn(mem) lambda_18210)
        bool _18211 = lt lambda_18209, qs32 3
        pf32 _18216 = cast lambda_18209
        pf32 _18217 = mul pf32 42, _18216
        br_18033(_18211, if_then_18212, if_else_18227)

    if_else_18227()
        bool _18228 = lt lambda_18209, qs32 5
        br_18033(_18228, if_then_18229, if_else_18237)

    if_else_18237()
        bool _18239 = lt lambda_18209, qs32 6
        br_18033(_18239, if_then_18240, if_else_18243)

    if_else_18243()
        mem _18245 = store lambda_18208, val_18244, _18217
        lambda_18210(_18245)

    if_then_18240()
        mem _18242 = store lambda_18208, val_18241, _18217
        lambda_18210(_18242)

    if_then_18229()
        bool _18230 = lt lambda_18209, qs32 4
        br_18033(_18230, if_then_18231, if_else_18234)

    if_else_18234()
        mem _18236 = store lambda_18208, val_18235, _18217
        lambda_18210(_18236)

    if_then_18231()
        mem _18233 = store lambda_18208, val_18232, _18217
        lambda_18210(_18233)

    if_then_18212()
        bool _18213 = lt lambda_18209, qs32 1
        br_18033(_18213, if_then_18214, if_else_18219)

    if_else_18219()
        bool _18220 = lt lambda_18209, qs32 2
        br_18033(_18220, if_then_18221, if_else_18224)

    if_else_18224()
        mem _18226 = store lambda_18208, val_18225, _18217
        lambda_18210(_18226)

    if_then_18221()
        mem _18223 = store lambda_18208, val_18222, _18217
        lambda_18210(_18223)

    if_then_18214()
        mem _18218 = store lambda_18208, val_18215, _18217
        lambda_18210(_18218)

    lambda_18143(mem lambda_18144, qs32 lambda_18145, fn(mem) lambda_18146)
        pf32 _18152 = cast lambda_18145
        bool _18147 = lt lambda_18145, qs32 3
        pf32 _18153 = mul pf32 42, _18152
        br_18033(_18147, if_then_18148, if_else_18163)

    if_else_18163()
        bool _18164 = lt lambda_18145, qs32 4
        br_18033(_18164, if_then_18165, if_else_18168)

    if_else_18168()
        bool _18170 = lt lambda_18145, qs32 5
        br_18033(_18170, if_then_18171, if_else_18174)

    if_else_18174()
        mem _18176 = store lambda_18144, val_18175, _18153
        lambda_18146(_18176)

    if_then_18171()
        mem _18173 = store lambda_18144, val_18172, _18153
        lambda_18146(_18173)

    if_then_18165()
        mem _18167 = store lambda_18144, val_18166, _18153
        lambda_18146(_18167)

    if_then_18148()
        bool _18149 = lt lambda_18145, qs32 1
        br_18033(_18149, if_then_18150, if_else_18155)

    if_else_18155()
        bool _18156 = lt lambda_18145, qs32 2
        br_18033(_18156, if_then_18157, if_else_18160)

    if_else_18160()
        mem _18162 = store lambda_18144, val_18161, _18153
        lambda_18146(_18162)

    if_then_18157()
        mem _18159 = store lambda_18144, val_18158, _18153
        lambda_18146(_18159)

    if_then_18150()
        mem _18154 = store lambda_18144, val_18151, _18153
        lambda_18146(_18154)

    lambda_18089(mem lambda_18090, qs32 lambda_18091, fn(mem) lambda_18092)
        bool _18093 = lt lambda_18091, qs32 2
        pf32 _18098 = cast lambda_18091
        pf32 _18099 = mul pf32 42, _18098
        br_18033(_18093, if_then_18094, if_else_18104)

    if_else_18104()
        bool _18105 = lt lambda_18091, qs32 3
        br_18033(_18105, if_then_18106, if_else_18109)

    if_else_18109()
        bool _18111 = lt lambda_18091, qs32 4
        br_18033(_18111, if_then_18112, if_else_18115)

    if_else_18115()
        mem _18117 = store lambda_18090, val_18116, _18099
        lambda_18092(_18117)

    if_then_18112()
        mem _18114 = store lambda_18090, val_18113, _18099
        lambda_18092(_18114)

    if_then_18106()
        mem _18108 = store lambda_18090, val_18107, _18099
        lambda_18092(_18108)

    if_then_18094()
        bool _18095 = lt lambda_18091, qs32 1
        br_18033(_18095, if_then_18096, if_else_18101)

    if_else_18101()
        mem _18103 = store lambda_18090, val_18102, _18099
        lambda_18092(_18103)

    if_then_18096()
        mem _18100 = store lambda_18090, val_18097, _18099
        lambda_18092(_18100)

    lambda_18042(mem lambda_18043, qs32 lambda_18044, fn(mem) lambda_18045)
        bool _18047 = lt lambda_18044, qs32 2
        pf32 _18057 = cast lambda_18044
        pf32 _18058 = mul pf32 42, _18057
        br_18033(_18047, if_then_18048, if_else_18063)

    if_else_18063()
        bool _18065 = lt lambda_18044, qs32 3
        br_18033(_18065, if_then_18066, if_else_18069)

    if_else_18069()
        mem _18071 = store lambda_18043, val_18070, _18058
        lambda_18045(_18071)

    if_then_18066()
        mem _18068 = store lambda_18043, val_18067, _18058
        lambda_18045(_18068)

    if_then_18048()
        bool _18050 = lt lambda_18044, qs32 1
        br_18033(_18050, if_then_18051, if_else_18060)

    if_else_18060()
        mem _18062 = store lambda_18043, val_18061, _18058
        lambda_18045(_18062)

    if_then_18051()
        mem _18059 = store lambda_18043, val_18055, _18058
        lambda_18045(_18059)

br_18033(bool br_18034, fn() br_18035, fn() br_18036)
madmann91 commented 6 years ago

Note that the code stops being fully evaluated at n = 5. Everything is OK for n = 0 ... n = 4. For example, when the outer loop is restricted to for n in range(0, 4) @{ /* ... */ }, I get this (the slots are removed later on by mem2reg):

module 'new_pe'

main_3807(mem mem_3808, [pf32]* res_3809, fn(mem, ()) return_3810) extern 
    pf32* n_3819 = lea res_3809, qs32 2
    pf32* n_3823 = lea res_3809, qs32 3
    pf32* n_3816 = lea res_3809, qs32 1
    pf32* n_3812 = lea res_3809, qs32 0
    mem _3814 = store mem_3808, n_3812, pf32 0
    mem _3817 = store _3814, n_3816, pf32 0
    mem _3821 = store _3817, n_3819, pf32 42
    mem _3825 = store _3821, n_3823, pf32 126
    return_3810(_3825, tuple ())
leissa commented 6 years ago

Works here, even for n = 32. Are you sure you are using my variant, where the inner loops also use this

range(0, n, @|i| { /*...*/ });

workaround?

Our mem2reg not removing these slots is a missing feature, as we don't consider leas at the moment. But LLVM saves us in this case.

madmann91 commented 6 years ago

Yes, I am using your variant, with the new syntax for n in range(0, 4) @{ /* ... */ }, which is equivalent (see my commits in Impala). The problem is not the slots; they are fine. The problem is the lt lambda_18044, qs32 3 and friends. This means that the bisection in SmallStack.read has not been eliminated.
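The number of residual comparisons follows directly from the bisection in make_small_stack_helper. The C++ sketch below (residual_comparisons is a hypothetical helper that only counts, it does not generate code) reproduces the counts visible in the dump above: three lt for n = 4 (lambda_18042), six for n = 7 (lambda_18207), i.e. n - 1 whenever the index is dynamic, and none when the index is static.

```cpp
#include <cassert>
#include <optional>

// Mirrors the bisection in make_small_stack_helper and counts the `lt`
// comparisons left in the residual dispatch over [begin, end).
// With a static index, every `i < m` is decided during PE and only one path
// survives; with a dynamic index, both branches are residualized.
int residual_comparisons(int begin, int end, std::optional<int> i) {
    if (end - begin <= 1) return 0;  // empty range or single slot: no test
    int m = (begin + end) / 2;
    if (i) return *i < m ? residual_comparisons(begin, m, i)
                         : residual_comparisons(m, end, i);
    return 1 + residual_comparisons(begin, m, i) + residual_comparisons(m, end, i);
}
```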

leissa commented 6 years ago

Works here - even with the new syntax:

$  impala rodent.impala -emit-thorin -Othorin 
module 'rodent'

main_9775(mem mem_9776, [pf32]* res_9777, fn(mem, ()) return_9778) extern 
    pf32* n_9781 = lea res_9777, qs32 0
    pf32* n_9785 = lea res_9777, qs32 1
    mem _9783 = store mem_9776, n_9781, pf32 0
    pf32* n_9788 = lea res_9777, qs32 2
    pf32* n_9800 = lea res_9777, qs32 5
    pf32* n_9804 = lea res_9777, qs32 6
    pf32* n_9808 = lea res_9777, qs32 7
    pf32* n_9792 = lea res_9777, qs32 3
    pf32* n_9796 = lea res_9777, qs32 4
    mem _9786 = store _9783, n_9785, pf32 0
    mem _9790 = store _9786, n_9788, pf32 42
    mem _9794 = store _9790, n_9792, pf32 126
    mem _9798 = store _9794, n_9796, pf32 252
    mem _9802 = store _9798, n_9800, pf32 420
    mem _9806 = store _9802, n_9804, pf32 630
    mem _9810 = store _9806, n_9808, pf32 882
    return_9778(_9810, tuple ())
madmann91 commented 6 years ago

By the way, this new commit you added (05e40c1f9755011df16a17d397bb137ba6053535) makes the PE diverge on one variant in Rodent, even though this variant was working previously. The divergent loop is the while (!queue.empty()) loop in partial_evaluation.cpp. It seems that it keeps processing continuations that come from PE (generated during mangling?).
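One common way to keep a worklist loop like this from running forever is to cache specializations by (continuation, static arguments) and to bound the total work. The sketch below assumes hypothetical integer ids for continuations and a single static argument; it is not how partial_evaluation.cpp is structured, only an illustration of why a cache handles cycles but cannot stop a static value that keeps growing.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <queue>
#include <set>
#include <utility>
#include <vector>

// (continuation id, static argument) identifies one specialization.
using Key = std::pair<int, int>;

// Worklist-driven specialization with a cache of keys that were already
// specialized. `successors` yields the calls the specialized body would make.
// The cache makes cycles terminate, but a static argument that keeps growing
// still produces fresh keys forever, so a fuel budget reports divergence.
std::optional<std::size_t> specialize_all(
        Key entry, const std::function<std::vector<Key>(Key)>& successors,
        std::size_t fuel = 10000) {
    std::queue<Key> queue;
    std::set<Key> done;
    queue.push(entry);
    while (!queue.empty()) {
        if (fuel-- == 0) return std::nullopt;  // out of budget: diverging
        Key k = queue.front();
        queue.pop();
        if (!done.insert(k).second) continue;  // cache hit: reuse it
        for (const Key& s : successors(k)) queue.push(s);
    }
    return done.size();  // number of distinct specializations
}
```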

madmann91 commented 6 years ago

I will give it a shot again.

leissa commented 6 years ago

Yes, most likely - but this doesn't mean it's a bug. We should get pe_info back. It really helps debugging and understanding what's going on.

madmann91 commented 6 years ago

I just tested it on my laptop. It seems to work there. I guess this is fixed then. I'll try to find a small example for the broken Rodent variant. It's basically the same code, so I do not get why this is not working now.

leissa commented 6 years ago

My last commit 66a81419b252c0fed2fddd30321bb2c8100b89d3 shouldn't affect anything, btw. It's currently not used.

madmann91 commented 6 years ago

I have narrowed the problem down. Here is a piece of code that does not go through:

fn @sort(len: int) -> () {
    if len > 1 {
        sort(len / 2);
    }
}

fn @eight(k: int) -> int {
    if k == 0 {
        8
    } else {
        eight(k - 1)
    }
}

fn main() -> () {
    sort(eight(1))
}

The good thing about this new PE mode is that it is more local, which makes it easier to find small snippets that reproduce a bug. Would adding annotations at the call site help here?

leissa commented 6 years ago

Just by looking at the code, I guess the problem is the @ on sort, which is recursive.

madmann91 commented 6 years ago

But here it works if you call sort(8). And eight(1) evaluates to 8 when used alone. I suspect that the return continuations need to be inlined in order to specialize a function with the result of another function. Here it does not work because eight() is recursive. You can achieve the same effect with:

fn @sort(len: int) -> () {
    if len > 1 {
        sort(len / 2);
    }
}

fn @eight(k: int) -> int {
    8
}

fn main() -> () {
    sort((eight(1) + eight(1)) / 2)
}

This prevents inlining of eight(), because there are two different calls to it. Therefore, it also breaks.

madmann91 commented 6 years ago

This is a bug. The pe_profile of eight() is set to false for some reason. Will investigate.

leissa commented 6 years ago

No, it's a different issue. We fold eight but don't propagate its result into its return continuation.
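A toy C++ model of this issue, assuming continuations are plain callbacks and that a PE-time value is either a known int or unknown (std::nullopt). The names eight, value_seen_by_continuation and unrolled_sort_calls are hypothetical. It shows that sort can only be unrolled if the constant produced by eight actually reaches the body of the return continuation.

```cpp
#include <cassert>
#include <functional>
#include <optional>

// A value during PE: a known constant, or std::nullopt if only known at run time.
using SVal = std::optional<int>;

// The non-recursive eight() in CPS: it "returns" by calling its continuation.
void eight(int /*k*/, const std::function<void(SVal)>& ret) { ret(8); }

// Hypothetical model of folding a call eight(1, ret): if the return
// continuation is inlined, its body sees the constant 8; if the call to ret
// stays residual, its body only sees ret's own (dynamic) parameter.
SVal value_seen_by_continuation(bool inline_return) {
    SVal seen;
    eight(1, [&](SVal v) { seen = inline_return ? v : SVal{}; });
    return seen;
}

// sort(len) only unrolls at PE time when len is static; the count includes
// the outermost call, e.g. sort(8) -> sort(4) -> sort(2) -> sort(1).
std::optional<int> unrolled_sort_calls(SVal len) {
    if (!len) return std::nullopt;  // dynamic len: sort remains a residual loop
    int calls = 1;
    for (int l = *len; l > 1; l /= 2) ++calls;
    return calls;
}
```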

madmann91 commented 6 years ago

Parameter elimination is now properly propagating the profile so that is fixed.

leissa commented 6 years ago

ah yes :)

leissa commented 6 years ago

This was definitely an issue, but there was still the problem I mentioned earlier that we didn't propagate into the return continuation. 246ec003a53dc8d7251e118bf0bc5b02b6b24cb2 should fix this.

madmann91 commented 6 years ago

The snippet with eight as a recursive function still does not work (with the latest commits).

leissa commented 6 years ago

I'm working on it...

leissa commented 6 years ago

The problem is this: given

fn f(a: int, ret: fn(int)) @(true, true)
    ...

and a call

f(23, k)

PE will specialize the call to this:

k(42)

but all continuations nested inside k still see k's parameter instead of 42.

After cleanup, this problem is gone, but by then it is too late. Also, just checking at the call site whether a continuation has only one user isn't enough (see 246ec00), because we might end up with dead/unreachable users of a callee.

So, in c6b084f I annotate continuations that get specialized with an all-true profile if they have only one user. But maybe this is too aggressive?

madmann91 commented 6 years ago

I think it is not aggressive enough ;) It does not work with non-tail-recursive functions:

fn @(?i) ilog2(i: int) -> int {
    if i <= 1 {
        0
    } else {
        ilog2(i / 2) + 1
    }
}

fn @sort(len: int) -> () {
    if len > 1 {
        sort(len / 2)
    }
}

fn main() -> () {
    sort(1 << ilog2(8))
}
leissa commented 6 years ago
fn foo(/*...*/) -> int { /*...*/ }

fn @bar(g: fn(/*...*/) -> int) -> int {
    //...
    let x = g(/*...*/);
    //...
    let y = g(/*...*/);
    //...
}

//...
bar(foo)

Here we would specialize foo twice, because at the moment we pass foo to bar, foo has only one user. That's something we may not want...

madmann91 commented 6 years ago

How about applying this only to basic blocks, not functions? After specialization, the return continuation is usually a basic block. This would take care of this case, at least, but correct me if I am wrong.

leissa commented 6 years ago

Ok, the log example doesn't work because the use counter may still include dead uses, even if we count before folding.

I think the best solution is to always annotate return-continuations built in impala with an all-true profile. In thorin we don't know which continuations are there on purpose and which ones are there by accident, if you know what I mean.
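The dead-use problem can be illustrated by counting uses only from continuations reachable from the entry. This is a sketch over a hypothetical graph of integer continuation ids, not thorin's use lists; in the test, continuation 3 stands for a stale copy left behind by folding, which a naive count still includes.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Toy CFG: each continuation lists the continuations its body references
// (as callee or as argument).
using Graph = std::map<int, std::vector<int>>;

// Count uses of `target`, either naively (every body in the world) or only
// from bodies reachable from `entry`. Dead, not-yet-cleaned-up continuations
// inflate the naive count, which defeats a "single user" heuristic.
int count_uses(const Graph& g, int target, int entry, bool reachable_only) {
    std::set<int> live;
    if (reachable_only) {
        std::vector<int> stack{entry};
        while (!stack.empty()) {
            int c = stack.back();
            stack.pop_back();
            if (!live.insert(c).second) continue;  // already visited
            auto it = g.find(c);
            if (it != g.end())
                for (int succ : it->second) stack.push_back(succ);
        }
    }
    int uses = 0;
    for (const auto& [cont, refs] : g) {
        if (reachable_only && !live.count(cont)) continue;  // skip dead bodies
        for (int r : refs) uses += (r == target);
    }
    return uses;
}
```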

leissa commented 6 years ago

see my last commit in impala and thorin

madmann91 commented 6 years ago

Can't we just annotate every call to a continuation argument that happens inside the folded continuation as to-be-inlined (with the new Run PrimOp)? I suppose your solution would imply that:

fn foo(i: int, ret: fn (int) -> !) -> !

does not behave like:

fn foo(i: int) -> int
leissa commented 6 years ago

No, I mean that

let x = foo(i);

is different from

foo(i, k)

because there might be a second call

bar(j, k)

and in this case I guess we don't want to specialize the residual call to k?

Another option would be to simply look for these things beforehand, count the uses and set an all-true profile. Then, do the PE.

Sorry, I don't get your suggestion. Can you elaborate a bit more?

madmann91 commented 6 years ago

Well, the problem with your previous approach is that it relied on the number of uses to perform inlining. Leaving aside the fact that this number is incorrect, this might be too aggressive, as you pointed out. In our case, we are only interested in the uses happening at one particular call site, not every use. Say we have a continuation f on which we are trying to perform PE:

f(..., ret: fn (int)):
  ...

  g(...):
    ret(...)

At the call site, we have something like:

call_to_f(...):
  f(..., f_cont)

f_cont(i: int):
  ...

We only want to eat calls to the continuation f_cont that happen inside the specialized continuation generated by dropping the call to f. Any other use of f_cont should be kept intact. In this example, we want to get rid of the call to f_cont in dropped_g:

call_to_f(...):
  ...

dropped_g(...):
  f_cont(...)

f_cont(i: int):
  ...
madmann91 commented 6 years ago

This transformation should be safe when f_cont is a basic block.

leissa commented 6 years ago

But even inside f we can have multiple ret calls which might end up in multiple residual calls. And in this case we don't want to inline, right?

f(..., ret: fn(int))
    br(..., T, F)
T()
    ret(23)
F()
    ret(42)

Now, suppose we have

cur(...)
    f(..., k)

and after dropping:

cur(...)
    f'()
f'()
    br(..., T', F')
T'()
    k(23)
F'()
    k(42)

Now, specializing those k-calls might be a bad idea.

madmann91 commented 6 years ago

I think it is fine if you do this only when k is of order 1 (a basic block). After all, the return continuation is here as an administrative term: It only says where to jump. By inlining it, we are merely performing the jump (back to the call site).
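The "order" mentioned here can be computed structurally: primitive types have order 0, and a function type has 1 plus the maximum order of its parameters. A basic block such as fn(mem) therefore has order 1, while a function that takes a return continuation has order 2. The C++ sketch below uses a hypothetical minimal Type representation, not thorin's type classes.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Type;
using TypePtr = std::shared_ptr<Type>;

// Minimal type representation for the sketch: either a primitive
// (int, mem, ...) or a function/continuation type with parameter types.
struct Type {
    bool is_fn = false;
    std::vector<TypePtr> params;
};

TypePtr prim() { return std::make_shared<Type>(); }

TypePtr fn(std::vector<TypePtr> params) {
    auto t = std::make_shared<Type>();
    t->is_fn = true;
    t->params = std::move(params);
    return t;
}

// order(primitive) = 0; order(fn(p1, ..., pn)) = 1 + max(order(pi)).
int order(const TypePtr& t) {
    if (!t->is_fn) return 0;
    int m = 0;
    for (const TypePtr& p : t->params) m = std::max(m, order(p));
    return 1 + m;
}
```

For example, fn(mem) has order 1 and fn(mem, int, fn(mem, int)) has order 2.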

leissa commented 6 years ago

But in my example above, we would double the whole program: one copy with 23, one with 42.

leissa commented 6 years ago

The precise approach would actually be to let the programmer decide whether the ret call should be specialized or not, maybe using certain conditions...

madmann91 commented 6 years ago

I see your point. Maybe we should do that when f_cont is a simple block (only passing arguments through, and only one use within the dropped call)? There would be no risk of program explosion and that would solve our problems in a transparent manner.
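The rule proposed in this comment can be written down as a small predicate. This is a sketch only: ContInfo and should_inline_return are hypothetical, and how the order, the forwarding property and the call count would be computed inside thorin is left open.

```cpp
#include <cassert>

// Hypothetical summary of a return continuation such as f_cont.
struct ContInfo {
    int order;          // 1 == basic block (it receives no continuation itself)
    bool only_forwards; // body just passes its parameters on to another jump
};

// Inline calls to f_cont inside the dropped (specialized) body only if f_cont
// is a simple basic block called exactly once there. With one call per branch,
// as in the T'/F' example, inlining would duplicate what f_cont leads to.
bool should_inline_return(const ContInfo& f_cont, int calls_in_dropped_body) {
    return f_cont.order == 1 && f_cont.only_forwards && calls_in_dropped_body == 1;
}
```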

leissa commented 6 years ago

So, this super-recursive thing we talked about didn't work out as expected. It is doable but requires more work, as we encountered a couple of problems. However, I implemented another trick in 83c517d93f0372f44559d0258fa2ad14df9dd52d and removed the workaround in AnyDSL/impala@ce4a48da393b4eadc472a534479a66beab4a9db7. So far, it seems like this works :)

leissa commented 6 years ago

The trick in 83c517d still isn't enough. My last commit is super brute-force and slow, but it works for now. This recursive stuff should work once we have the implementation right, but I'll return to it after vacation.