Frege / frege

Frege is a Haskell for the JVM. It brings purely functional programming to the Java platform.
https://github.com/Frege/frege/wiki/_pages

tail call optimization with limited stack depth #245

Closed minad closed 8 years ago

minad commented 8 years ago

Hi,

I was playing around a bit with Frege since I wanted to see how to run a functional language on the JVM. I wonder if it would be possible to implement a different tail call scheme, as described here: http://stackoverflow.com/questions/4527795/handling-stackoverflow-in-java-for-trampoline (class CPS3).

I think it works similarly to the Chicken Scheme trampoline, where the stack is used up to a certain depth and thunks are created only at certain points.

public static Thunk even(final int n,final int depth) {
    if (depth >= 1000) {
        return new Thunk() {
            public Thunk compute() {
                return even(n, 0);
            }
        };
    }
    if (n == 0) return new Thunk(true);
    else return odd(n-1, depth+1);
}

Thunk.force(even(n, 0))
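The fragment above relies on a Thunk class and an odd function that the issue does not show. Below is a self-contained sketch of the same depth-bounded trampoline. It assumes a Thunk that is either a finished boolean result or a suspended call resumed via compute(). The field names, the MAX_DEPTH constant, and the class name DepthBoundedTrampoline are hypothetical, not taken from the linked answer.

```java
public final class DepthBoundedTrampoline {
    // Assumed Thunk shape: either a finished result or a suspended call
    // that can be resumed with compute().
    static class Thunk {
        final boolean done;
        final boolean value;

        Thunk() {                 // suspended computation
            this.done = false;
            this.value = false;
        }

        Thunk(boolean value) {    // finished computation
            this.done = true;
            this.value = value;
        }

        Thunk compute() {
            throw new IllegalStateException("a finished Thunk cannot be resumed");
        }

        // The trampoline: resume suspended thunks until a result appears.
        static boolean force(Thunk t) {
            while (!t.done) {
                t = t.compute();
            }
            return t.value;
        }
    }

    // Hypothetical name for the 1000-frame budget used in the fragment.
    static final int MAX_DEPTH = 1000;

    static Thunk even(final int n, final int depth) {
        if (depth >= MAX_DEPTH) {
            // Budget exhausted: unwind the stack and restart from depth 0.
            return new Thunk() {
                @Override
                Thunk compute() {
                    return even(n, 0);
                }
            };
        }
        if (n == 0) return new Thunk(true);
        return odd(n - 1, depth + 1);
    }

    static Thunk odd(final int n, final int depth) {
        if (depth >= MAX_DEPTH) {
            return new Thunk() {
                @Override
                Thunk compute() {
                    return odd(n, 0);
                }
            };
        }
        if (n == 0) return new Thunk(false);
        return even(n - 1, depth + 1);
    }

    public static void main(String[] args) {
        System.out.println(Thunk.force(even(100_000_000, 0))); // prints true
    }
}
```

Most calls remain ordinary Java calls. Only every 1000th frame allocates a suspension and unwinds the stack.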

When I run TailCalls.fr, it takes about 2 s; with this optimization, it takes only about 0.2 s.

However, the question is also how this would play together with laziness. But for strict functions (or after strictness analysis), I guess this would make sense.

Ingo60 commented 8 years ago

I wonder how a lambda-based solution would compare (i.e. letting Thunk be Callable).

That being said, it would not be too difficult to implement such a scheme. First, the compiler partitions all functions into groups, where the functions in one group call each other recursively. If the size of a group is 1, there is no mutual recursion. If it is 2, we have something like even/odd. And if it is more, we have something even more complicated.

Anyway, with regard to static function calls, the compiler knows where deep stacks are looming and can return Thunks instead of values there, avoiding long chains of tail calls. Now, we could have a private counter for each mutually recursive group and do something like:

return ((++counter & 0x3ff) == 0) ? (() -> odd(n - 1)) : odd(n - 1);

(and similarly for all other tail calls within the group, which all have the same return type, e.g. Callable)

This would then only return a Callable every 1024th time.
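Here is a minimal sketch of the counter idea. It assumes every function in the even/odd group returns Object (either a Boolean result or a Callable to resume) and that the group shares one plain static counter. CountedBounce and run are hypothetical names. The test (++counter & 0x3ff) == 0 is true once every 1024 increments, so at most about 1024 Java frames pile up before the trampoline in run unwinds them.

```java
import java.util.concurrent.Callable;

public final class CountedBounce {
    // Hypothetical private counter for the even/odd group. It is a plain
    // static field, so it is NOT thread-safe.
    private static int counter = 0;

    // Each function in the group returns Object: either a Boolean result or
    // a Callable that resumes the pending tail call on a fresh stack.
    static Object even(long n) {
        if (n == 0) return Boolean.TRUE;
        if ((++counter & 0x3ff) == 0) {          // every 1024th tail call
            return (Callable<Object>) () -> odd(n - 1);
        }
        return odd(n - 1);
    }

    static Object odd(long n) {
        if (n == 0) return Boolean.FALSE;
        if ((++counter & 0x3ff) == 0) {
            return (Callable<Object>) () -> even(n - 1);
        }
        return even(n - 1);
    }

    // Trampoline: keep calling until the result is no longer a Callable.
    static boolean run(Object r) {
        while (r instanceof Callable) {
            try {
                r = ((Callable<?>) r).call();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return (Boolean) r;
    }

    public static void main(String[] args) {
        System.out.println(run(even(100_000_000L))); // prints true
    }
}
```

Note the parentheses around (++counter & 0x3ff): in Java, == binds tighter than &, so without them the expression would not compile.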

minad commented 8 years ago
  1. I tried Java 1.8 lambdas and they are generally much better: they don't generate lots of inner classes, and the performance is the same or even better. I would recommend switching from anonymous classes to lambdas everywhere possible.
  2. Is it really that simple with those function groups? Couldn't you create lambda functions that are passed around to higher-order functions and perform deep tail calls in such a way that this static analysis fails and you cannot form groups?
  3. What about multithreading and such a counter? Would you create an extra instance of the group for every thread?
minad commented 8 years ago

The following scheme works pretty well:

public interface Thunk {
    Object call(Tail tail);
}

public class Tail {
    private static final int maxDepth = 500;

    public int depth;

    public Thunk thunk;

    public boolean limit() {
        return ++this.depth >= maxDepth;
    }

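    // Trampoline: once the initial call has unwound, resume the pending
    // thunk with the depth budget reset, until no thunk is left.
    @SuppressWarnings("unchecked")
    public <T> T run(int d, T ret) {
        this.depth = d;
        while (this.thunk != null) {
            Thunk fn = this.thunk;
            this.thunk = null;
            ret = (T) fn.call(this);
            this.depth = d;
        }
        return ret;
    }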
}

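public static boolean even(int n, Tail tail) {
    if (tail.limit()) {
        // Depth budget used up: store the pending call and unwind the stack.
        // The returned value is a placeholder that Tail.run() discards.
        tail.thunk = t -> even(n, t);
        return false;
    }
    return n == 0 ? true : odd(n - 1, tail);
}

public static boolean odd(int n, Tail tail) {
    if (tail.limit()) {
        tail.thunk = t -> odd(n, t);
        return false;
    }
    return n == 0 ? false : even(n - 1, tail);
}

public static boolean test() {
    Tail tail = new Tail();
    return tail.run(tail.depth, even(200000000, tail));
}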
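Below is a single-file sketch of the same Tail scheme. Thunk and Tail are nested types so it compiles as one class; TailDemo, sumTo, and sum are hypothetical names. It uses a tail-recursive accumulator to show run() with a non-boolean result type. Each call chain carries its own Tail object, so two chains can run on different threads without sharing a counter.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public final class TailDemo {
    interface Thunk {
        Object call(Tail tail);
    }

    static final class Tail {
        private static final int maxDepth = 500;
        int depth;
        Thunk thunk;

        boolean limit() {
            return ++depth >= maxDepth;
        }

        @SuppressWarnings("unchecked")
        <T> T run(int d, T ret) {
            depth = d;
            while (thunk != null) {
                Thunk fn = thunk;
                thunk = null;
                ret = (T) fn.call(this);
                depth = d;
            }
            return ret;
        }
    }

    // Tail-recursive sum with an accumulator.
    static long sumTo(long n, long acc, Tail tail) {
        if (tail.limit()) {
            tail.thunk = t -> sumTo(n, acc, t);
            return 0L; // placeholder, discarded by Tail.run
        }
        return n == 0 ? acc : sumTo(n - 1, acc + n, tail);
    }

    static long sum(long n) {
        Tail tail = new Tail(); // one Tail per call chain: no shared mutable counter
        return tail.run(tail.depth, sumTo(n, 0, tail));
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        Future<Long> a = pool.submit(() -> sum(10_000_000L));
        Future<Long> b = pool.submit(() -> sum(20_000_000L));
        System.out.println(a.get() + " " + b.get());
        pool.shutdown();
    }
}
```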
Ingo60 commented 8 years ago

Well, there is an even easier solution that doesn't even need a counter.

It is applicable if the mutual recursions appear only in tail position.

The really problematic cases are recursions that DON'T happen in the tail.

A prototypical example is foldr.
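The following Java sketch makes the distinction concrete with both folds over a range of longs; Folds and its methods are illustrative names. Java is strict, so the loop corresponds to foldl' rather than lazy foldl. The key point is that foldr's combining step runs only after its recursive call returns, so the tail-call schemes discussed above do not apply to it.

```java
import java.util.function.LongBinaryOperator;

public final class Folds {
    // foldr f z [lo..hi] = f lo (foldr f z [lo+1..hi])
    // f runs only after the recursive call returns, so every element keeps a
    // stack frame alive: this is not a tail call and cannot become a loop.
    static long foldr(LongBinaryOperator f, long z, long lo, long hi) {
        if (lo > hi) return z;
        return f.applyAsLong(lo, foldr(f, z, lo + 1, hi));
    }

    // foldl f z [lo..hi] = foldl f (f z lo) [lo+1..hi]
    // The recursive call is in tail position, so it can be written as a loop.
    static long foldl(LongBinaryOperator f, long z, long lo, long hi) {
        long acc = z;
        for (long i = lo; i <= hi; i++) {
            acc = f.applyAsLong(acc, i);
        }
        return acc;
    }

    public static void main(String[] args) {
        System.out.println(foldl(Long::sum, 0, 1, 10_000_000)); // 50000005000000
        try {
            System.out.println(foldr(Long::sum, 0, 1, 10_000_000));
        } catch (StackOverflowError e) {
            System.out.println("foldr: StackOverflowError");
        }
    }
}
```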

Another problematic case is stack overflows stemming from thunk nesting. The classic example here is foldl, and unfortunately, it can't always be cured by using foldl'.

Consider

import Data.Monoid
import Data.wrapper.Num
foldl' (<>) Nothing (map (Just . Sum) [1..1000000])

This looks artificial, but once the FTP (Foldable/Traversable in Prelude) proposal is in place, such things can easily happen.

I don't have a really convincing idea yet of how to compute this on the JVM.
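The following rough Java model shows why foldl' does not help here. It assumes Maybe is modeled by Optional and a lazy Sum by a Supplier; ThunkNesting, mappend, and strictFoldl are hypothetical names. foldl' evaluates each accumulator only to weak head normal form, i.e. it decides between Nothing and Just. The number inside stays a chain of deferred additions, and forcing it at the end needs at least one stack frame per element.

```java
import java.util.Optional;
import java.util.function.Supplier;

public final class ThunkNesting {
    // Models Just (Sum a) <> Just (Sum b) = Just (Sum (a + b)) in a lazy
    // language: the outer constructor (Optional) is built right away, but
    // the addition inside is deferred as a Supplier, i.e. a thunk.
    static Optional<Supplier<Long>> mappend(Optional<Supplier<Long>> x,
                                            Optional<Supplier<Long>> y) {
        if (!x.isPresent()) return y;
        if (!y.isPresent()) return x;
        Supplier<Long> a = x.get();
        Supplier<Long> b = y.get();
        Supplier<Long> sum = () -> a.get() + b.get();
        return Optional.of(sum);
    }

    // foldl' (<>) Nothing (map (Just . Sum) [1..n]): every step evaluates the
    // accumulator to an Optional (weak head normal form) but never forces
    // the number inside, so a chain of n nested thunks builds up.
    static Optional<Supplier<Long>> strictFoldl(long n) {
        Optional<Supplier<Long>> acc = Optional.empty();
        for (long i = 1; i <= n; i++) {
            final long x = i;
            Supplier<Long> leaf = () -> x;
            acc = mappend(acc, Optional.of(leaf));
        }
        return acc;
    }

    public static void main(String[] args) {
        Optional<Supplier<Long>> r = strictFoldl(1_000_000); // no deep stack yet
        try {
            System.out.println(r.get().get()); // forcing recurses ~1,000,000 levels
        } catch (StackOverflowError e) {
            System.out.println("StackOverflowError while forcing the Sum");
        }
    }
}
```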

minad commented 8 years ago

> Well, there is an even easier solution that doesn't even need a counter.

You mean by inlining mutually recursive functions and converting them to a loop? But wouldn't that need quite a lot of code rewriting?

> The really problematic cases are recursions that DON'T happen in the tail.

What kind of transformation could be done in that case? What is GHC doing?

Ingo60 commented 8 years ago

> You mean by inlining mutually recursive functions

Quite so. We can transform to:

evenodd 0 n = -- right hand side of original even
evenodd _  n = -- right hand side of original odd
even  = evenodd 0
odd  = evenodd 1

And in the right-hand sides of evenodd, replace every occurrence of even with evenodd 0 and of odd with evenodd 1.

This is only slightly more code than the original.
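Here is a Java rendering of the merged function. It assumes a strict language and an int selector, where 0 stands for even and 1 for odd; EvenOdd and its constants are illustrative names. Once even and odd are one function, every mutual call becomes a self tail call, which can be compiled as a loop that merely updates the arguments.

```java
public final class EvenOdd {
    // Selector values mirroring "evenodd 0" (even) and "evenodd 1" (odd).
    static final int EVEN = 0;
    static final int ODD = 1;

    // Every former call between even and odd is now a self tail call of
    // evenodd, so it becomes "update the arguments and loop".
    static boolean evenodd(int which, long n) {
        while (true) {
            if (which == EVEN) {
                if (n == 0) return true;   // right-hand side of even
                which = ODD;               // tail call: evenodd 1 (n - 1)
            } else {
                if (n == 0) return false;  // right-hand side of odd
                which = EVEN;              // tail call: evenodd 0 (n - 1)
            }
            n = n - 1;
        }
    }

    static boolean even(long n) { return evenodd(EVEN, n); }

    static boolean odd(long n) { return evenodd(ODD, n); }

    public static void main(String[] args) {
        System.out.println(even(200_000_000L)); // prints true, constant stack
    }
}
```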

> What kind of transformation could be done in that case?

I'd be quite happy if I knew this.

> What is GHC doing?

AFAIK (but try it out for yourself), it does nothing about it! It just relies on its runtime system to provide as much stack space as is needed. And this usually works, because unlike on the JVM, the stack is not fixed.

Therefore, the following is possible:

ingo@delluntu:~/Frege/frege$ ghci 
GHCi, version 7.6.3: http://www.haskell.org/ghc/  :? for help
Loading package ghc-prim ... linking ... done.
Loading package integer-gmp ... linking ... done.
Loading package base ... linking ... done.
Prelude> foldr (+) 0 [1..100000000]
*** Exception: stack overflow
Prelude> 

Granted, it makes a difference whether you get that stack overflow at 10_000 or at 1_000_000_000 elements. But the fundamental problem is the same.
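On the JVM, a thread's stack size is chosen when the thread is created and stays fixed. The Thread(ThreadGroup, Runnable, String, long stackSize) constructor lets you request a larger one, which only moves the limit. The sketch below shows this; BigStack and its methods are illustrative names. Per the Thread Javadoc, the VM may ignore the stackSize request on some platforms, in which case the worker overflows and the result stays 0.

```java
public final class BigStack {
    // Non-tail recursion, one frame per element, like foldr (+) 0 [1..n].
    static long sumTo(long n) {
        return n == 0 ? 0 : n + sumTo(n - 1);
    }

    // Runs sumTo on a new thread whose stack size is requested up front and
    // is fixed for the thread's lifetime.
    static long sumOnThread(long n, long stackBytes) {
        long[] result = new long[1];
        Thread worker = new Thread(null, () -> result[0] = sumTo(n), "big-stack", stackBytes);
        worker.start();
        try {
            worker.join(); // join() also makes result[0] visible to this thread
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return result[0];
    }

    public static void main(String[] args) {
        // 1,000,000 frames would overflow a default-sized stack; request 512 MiB.
        System.out.println(sumOnThread(1_000_000, 512L << 20));
    }
}
```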

minad commented 8 years ago

Hi!

> And in the right-hand sides of evenodd, replace every occurrence of even with evenodd Even and of odd with evenodd Odd

> This is only slightly more code than the original.

Yes, sure. But maybe there are cases where this cannot be done. Maybe for higher-order functions between modules? Then the scheme given above could provide a fast solution.

> What is GHC doing?

> AFAIK (but try it out for yourself), it does nothing about it! It just relies on its runtime system to provide as much stack space as is needed. And this usually works, because unlike on the JVM, the stack is not fixed.

Ah, I see: the JVM stack is much more limited.

What about the ugly workaround of taking a worker from a thread pool if the stack size becomes critical? Would that be possible?

Ingo60 commented 8 years ago

> Maybe for higher-order functions between modules?

First off, we must recognize that lambdas in Haskell and Frege cannot be recursive (well, they can when you use tricks like type fixpoints, etc., but I think we can agree that this is of no practical relevance).

The reason is that any Y combinator must use the omega combinator

ω x = x x

and this would not type check, as you can easily convince yourself.

Hence, yes, we could turn the function even into a higher-order function and pass odd to it. But we cannot pass it an odd that was itself passed even: that would simply not type check!

Hence, there are only three possible cases when we see a tail call in some function:

  1. call a function that is known to be in the same group (mutual recursion)
  2. call some other, unrelated function
  3. call an unknown function passed as an argument. But this cannot be a higher-order function that leads to recursion (unless someone used fix f = f (fix f), but again: no practical relevance).

> ... worker from a thread pool if the stack size becomes critical? Would that be possible?

Answer from Radio Yerevan: in principle yes, but it doesn't solve the problem, since there are only so many threads, and each of them has a fixed stack size. Let's say we allow ourselves 100 threads. Then we can compute the sum of one million numbers...
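The sketch below models the thread-pool workaround and shows why it only multiplies the limit. When a hypothetical per-thread frame budget of LIMIT is used up, the rest of a non-tail-recursive sum continues on a fresh thread while the current one blocks in join(). The number of simultaneously live threads therefore grows with n / LIMIT. ThreadHandoff, sumFrom, and LIMIT are illustrative names, and a plain new Thread stands in for a pool.

```java
import java.util.concurrent.atomic.AtomicInteger;

public final class ThreadHandoff {
    // Hypothetical per-thread frame budget, well below a default JVM stack.
    static final int LIMIT = 1_000;
    static final AtomicInteger threadsUsed = new AtomicInteger();

    // Non-tail-recursive sum of lo..hi, like foldr (+) 0 [lo..hi]: the
    // addition happens after the recursive call returns.
    static long sumFrom(long lo, long hi, int depth) {
        if (lo > hi) return 0;
        if (depth >= LIMIT) return onFreshThread(lo, hi);
        return lo + sumFrom(lo + 1, hi, depth + 1);
    }

    // Continue the recursion on a new thread with an empty stack and wait.
    // The waiting thread stays blocked, so live threads pile up.
    static long onFreshThread(long lo, long hi) {
        threadsUsed.incrementAndGet();
        long[] result = new long[1];
        Thread worker = new Thread(() -> result[0] = sumFrom(lo, hi, 0));
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return result[0];
    }

    public static void main(String[] args) {
        long s = sumFrom(1, 100_000, 0);
        System.out.println(s + " using " + threadsUsed.get() + " extra threads");
    }
}
```

With LIMIT = 1000, summing 1..100,000 this way ends up with 100 threads alive at once: the caller plus 99 workers, each blocked on the next.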

I want a solution where we can easily compute the sum of twenty billion numbers (with foldr, that is; it already works with fold on plain numbers), without the user noticing that anything extraordinary is happening.

minad commented 8 years ago

Thanks for the explanations! It would be a bit like Radio Yerevan, I agree ;)