jank-lang / jank

The native Clojure dialect hosted on LLVM
https://jank-lang.org
Mozilla Public License 2.0

Use a meta-based approach to inlining #81

Open jeaye opened 4 months ago

jeaye commented 4 months ago

Background reading

I covered how Clojure's polymorphic arithmetic works, as well as the inlining that we're doing, here: https://jank-lang.org/blog/2023-04-07-ray-tracing/

jank's approach to polymorphism has since changed and no longer uses inheritance, but the overall concept is the same and the inlining hasn't changed at all.

Why inline?

In short, calling functions through vars is heavyweight. Every call has to fetch the var's current value (since the value can change at any time), and fetching a var's value requires synchronization. After that, the function itself still needs to be called, which is extra work. Inlining lets the compiler effectively copy/paste the function's body into the call site, so we neither fetch the var's value nor call the function.

The big downside is that we can no longer replace that function and have all of the existing call sites use the new definition, since some of them may have been inlined. So inlining only makes sense when we're doing AOT builds and not looking to replace functions later.
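To make the trade-off concrete, here is a minimal C++ sketch. The var type, call_through_var, and call_inlined are hypothetical stand-ins made up for illustration; they are not jank's runtime types. The sketch only shows the shape of the cost (lock, fetch, call) and why an inlined call site stops tracking redefinitions.

```cpp
#include <functional>
#include <mutex>
#include <utility>

// Hypothetical stand-in for a var: a synchronized, rebindable slot holding a function.
struct var
{
  using fn_type = std::function<long(long)>;

  explicit var(fn_type f)
    : root{ std::move(f) }
  {
  }

  // Every call through the var pays for this lock and fetch.
  fn_type deref() const
  {
    std::lock_guard<std::mutex> const lock{ mtx };
    return root;
  }

  // Redefinition replaces the root, so later derefs see the new function.
  void bind_root(fn_type f)
  {
    std::lock_guard<std::mutex> const lock{ mtx };
    root = std::move(f);
  }

private:
  mutable std::mutex mtx;
  fn_type root;
};

// A call site which goes through the var: fetch the value, then call it.
inline long call_through_var(var const &v, long const x)
{
  return v.deref()(x);
}

// The same call site with the original body (x + 1) pasted in.
// No lock, no fetch, no call, but it also never sees a redefinition.
inline long call_inlined(long const x)
{
  return x + 1;
}
```

If the var is first bound to a function returning x + 1 and later rebound to one returning x + 100, call_through_var picks up the new behavior while call_inlined keeps the old one, which is the reason inlining is limited to AOT builds here.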

However, in jank, we inline for a second reason: avoiding boxing. When working with math, especially, we want to avoid boxing as much as possible. jank has added a couple of meta keywords related to this, such as :supports-unboxed-input? and :unboxed-output?, and they tie directly to inlined function calls. The reason is that functions defined in jank require every input and output to be boxed, and we want to avoid that where possible with math. This is why inlining (for the purpose of unboxing) can't just be the same as copy/pasting the function body: the code we want to run must not involve any make_box calls and can't assign to __value (which is a box).

What Clojure does

Looking at clojure.core, we can easily find some examples of what Clojure supports for inlining:

(defn count
  "Returns the number of items in the collection. (count nil) returns
  0.  Also works on strings, arrays, and Java Collections and Maps"
  {
   :inline (fn  [x] `(. clojure.lang.RT (count ~x)))
   :added "1.0"}
  [coll] (clojure.lang.RT/count coll))

In this case, a call to (count foo) will be expanded (like macro expansion) to (. clojure.lang.RT (count foo)). The :inline fn receives the argument form as x and splices it in with ~x. Notice the syntax quoting and unquoting. This is just like macro expansion.

When a function has multiple arities, Clojure allows for specifying which arities can be inlined. It does that with another key, :inline-arities, whose value is a predicate function called with the number of arguments at the call site. Idiomatically, sets are often used, since a set is itself a function which returns a truthy value for its members.

; Using a set to say only arity 2 can be inlined.
(defn <
  "Returns non-nil if nums are in monotonically increasing order,
  otherwise false."
  {:inline (fn [x y] `(. clojure.lang.Numbers (lt ~x ~y)))
   :inline-arities #{2}
   :added "1.0"}
  ([x] true)
  ([x y] (. clojure.lang.Numbers (lt x y)))
  ([x y & more]
   (if (< x y)
     (if (next more)
       (recur y (first more) (next more))
       (< y (first more)))
     false)))

; But Clojure also defines some other functions which are more flexible.
(defn ^:private >1? [n] (clojure.lang.Numbers/gt n 1))

(defn +
  "Returns the sum of nums. (+) returns 0. Does not auto-promote
  longs, will throw on overflow. See also: +'"
  {:inline (nary-inline 'add 'unchecked_add)
   :inline-arities >1?
   :added "1.2"}
  ([] 0)
  ([x] (cast Number x))
  ([x y] (. clojure.lang.Numbers (add x y)))
  ([x y & more]
     (reduce1 + (+ x y) more)))
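Since the compiler side of jank is C++, the two styles of :inline-arities shown above can be modeled as one predicate type. This is only a sketch: arity_predicate, from_set, and more_than_one are hypothetical names, not part of Clojure or jank.

```cpp
#include <cstddef>
#include <functional>
#include <set>

// Hypothetical: an :inline-arities value modeled as a predicate over argument counts.
using arity_predicate = std::function<bool(std::size_t)>;

// Mirrors #{2}: a set used as a function accepts exactly its members.
inline arity_predicate from_set(std::set<std::size_t> arities)
{
  return [arities](std::size_t const n) { return arities.count(n) != 0; };
}

// Mirrors >1?: accepts any call with two or more arguments.
inline arity_predicate more_than_one()
{
  return [](std::size_t const n) { return n > 1; };
}
```

With these, from_set({ 2 }) accepts only two-argument calls, like <, while more_than_one() also accepts the variadic case, like +.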

What jank does

jank currently hard-codes this inlining, during codegen. We say "if we have a call, two params, and the var name is clojure.core/+, replace it with this". This is what we'll want to replace.

The code for that starts here: https://github.com/jank-lang/jank/blob/528d53732e9776d5f771f4462b3de55d5b0c7d04/compiler%2Bruntime/src/cpp/jank/codegen/processor.cpp#L675

What jank should do

We have three goals here.

  1. Support normal inlining
  2. Support unboxed inlining
  3. Control both of those per-arity

Support normal inlining

We should do exactly as Clojure does here, having normal inlining effectively work like macro expansion. We'll use the :inline and :inline-arities keys. However, having to duplicate the function body all of the time is a chore, so let's also support :inline true, which will just use the function body from the correct arity. Finally, let's imply :inline true if :inline-arities is present and the function returns true for that arity. This means normal inlining can happen in three ways:

  1. The :inline key has a fn which returns a new list of data/code
  2. The :inline key is true, which will encourage the compiler to inline that fn (the compiler may not do it and also the compiler may inline a fn which doesn't have :inline set)
  3. The :inline-arities key has a function which is used to control which arities are inlined. By default, it is identity, which will be truthy for all arities (every number, including 0, is truthy)

Note that the inputs to the :inline fn are always what was present at the call site, just like a macro.

For now, if there is an :inline key (or :unboxed-inline key), let's always inline. We can be smarter about it in the future.
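One possible reading of these rules, written as a small C++ decision function. Everything here (inline_kind, inline_meta, choose_normal_inline) is a hypothetical sketch for discussion, not existing jank code, and it follows the "always inline for now" rule rather than any heuristic.

```cpp
#include <cstddef>
#include <functional>

enum class inline_kind
{
  none,      // Call through the var as usual.
  expand_fn, // :inline holds a fn; call it with the call site's argument forms.
  copy_body  // :inline true (explicit or implied); reuse the body of the matching arity.
};

// Hypothetical summary of the inlining-related meta on a fn.
struct inline_meta
{
  bool has_inline_fn{ false }; // :inline is a fn
  bool inline_true{ false };   // :inline is true
  // :inline-arities; left empty when the key is absent.
  std::function<bool(std::size_t)> inline_arities;
};

inline inline_kind choose_normal_inline(inline_meta const &m, std::size_t const arg_count)
{
  bool const has_arities{ static_cast<bool>(m.inline_arities) };

  // None of the keys are present, so there's nothing to inline.
  if(!m.has_inline_fn && !m.inline_true && !has_arities)
  {
    return inline_kind::none;
  }

  // A missing :inline-arities accepts every arity; a present one must accept this call.
  if(has_arities && !m.inline_arities(arg_count))
  {
    return inline_kind::none;
  }

  if(m.has_inline_fn)
  {
    return inline_kind::expand_fn;
  }

  // Either :inline true was given, or it's implied by :inline-arities.
  return inline_kind::copy_body;
}
```

Note that the third case falls out naturally: a fn with only :inline-arities set ends up in copy_body for the arities its predicate accepts.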

Support unboxed inlining

Instead of this:

(defn
  ^{:arities {1 {:supports-unboxed-input? true
                 :unboxed-output? true}}}
  sqrt [o]
  (native/raw "__value = make_box(std::sqrt(runtime::detail::to_real(~{ o })));"))

And this:

        else if(ref->qualified_name->equal(runtime::obj::symbol{ "clojure.core", "sqrt" }))
        {
          format_elided_var("jank::runtime::sqrt(",
                            ")",
                            ret_tmp.str(false),
                            expr.arg_exprs,
                            fn_arity,
                            false,
                            box_needed);
          elided = true;
          ret_tmp = { ret_tmp.unboxed_name, box_needed };
        }

Let's do this:

(defn
  ^{:unboxed-inline (fn [o]
                      (str "jank::runtime::sqrt(" o ")"))
    :arities {1 {:supports-unboxed-input? true
                 :unboxed-output? true}}}
  sqrt [o]
  (native/raw "__value = make_box(std::sqrt(runtime::detail::to_real(~{ o })));"))

The :unboxed-inline meta will hold a function which returns the string that gets substituted in at the call site. That function takes the same number of parameters as the arity being inlined, so a single multi-arity function can cover several arities which take different values. Each input to that function is the handle to the corresponding argument expression.

The same :inline-arities key should apply here.

If a function has both :inline and :unboxed-inline, we need to choose the correct one based on whether or not our inputs are boxed and we need a boxed output.

  1. If our inputs are unboxed and we don't need a boxed output, use :unboxed-inline
  2. If our inputs are unboxed and we do need a boxed output, use :unboxed-inline and wrap it in make_box
  3. If our inputs are boxed, regardless of whether we need a box, use :inline
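The three rules above boil down to a function of two booleans. A minimal sketch, with hypothetical names (inline_choice, choose_inline) chosen just for this example:

```cpp
enum class inline_choice
{
  boxed_inline,        // Rule 3: use :inline.
  unboxed_inline,      // Rule 1: use :unboxed-inline as is.
  unboxed_inline_boxed // Rule 2: use :unboxed-inline, wrapped in make_box.
};

// Assumes the fn has both :inline and :unboxed-inline available.
inline inline_choice choose_inline(bool const inputs_unboxed, bool const box_needed)
{
  // Boxed inputs always go through :inline, whether or not a box is needed.
  if(!inputs_unboxed)
  {
    return inline_choice::boxed_inline;
  }
  if(box_needed)
  {
    return inline_choice::unboxed_inline_boxed;
  }
  return inline_choice::unboxed_inline;
}
```

The sketch doesn't cover what happens when only one of the two keys is present; that case isn't specified above.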

Codegen

Let's take our sqrt example and work it through. Let's say we have a call like this:

(let [a 1.0
      s (sqrt a)]
  (+ a s))

With the way things work now, a is unboxed, s is unboxed, and the call to + will be unboxed (and then wrapped in a make_box), since let is expecting a boxed value out. Let's see the codegen (again, for how it works now).

jank::runtime::object_ptr call() final
{
  using namespace jank;
  using namespace jank::runtime;
  jank::profile::timer __timer{ "repl_fn" };
  object_ptr const repl_fn{ this };
  object_ptr let_7{ jank::runtime::obj::nil::nil_const() };
  {
    {
      auto const a_2(const_1__unboxed);
      auto const call_8(jank::runtime::sqrt(a_2));
      {
        auto const s_4(call_8);
        auto const call_9(jank::make_box(jank::runtime::add(a_2, s_4)));
        let_7 = call_9;
      }
    }
  }
  return let_7;
}

With :unboxed-inline, the generated code will be exactly the same. The difference is how we get there: a_2 will be passed in as the parameter to the :unboxed-inline function, so the string returned from the function is "jank::runtime::sqrt(a_2)".
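As a rough model of that last step, the sketch below plays the role of the :unboxed-inline functions in C++ and reproduces the strings seen in the codegen. The function names are hypothetical, and add_unboxed_inline assumes + would get an :unboxed-inline of the same shape as sqrt, which isn't spelled out above.

```cpp
#include <string>

// Hypothetical counterpart of sqrt's :unboxed-inline fn.
// It receives the handle of the argument expression, not its value.
inline std::string sqrt_unboxed_inline(std::string const &o)
{
  return "jank::runtime::sqrt(" + o + ")";
}

// Assumed two-argument equivalent for +.
inline std::string add_unboxed_inline(std::string const &l, std::string const &r)
{
  return "jank::runtime::add(" + l + ", " + r + ")";
}

// Wraps the inlined string in make_box when the consumer needs a boxed value.
inline std::string emit_unboxed_call(std::string const &inlined, bool const box_needed)
{
  if(box_needed)
  {
    return "jank::make_box(" + inlined + ")";
  }
  return inlined;
}
```

Calling sqrt_unboxed_inline("a_2") with no box needed gives the initializer of call_8, and add_unboxed_inline("a_2", "s_4") with a box needed gives the initializer of call_9.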