beartype / plum

Multiple dispatch in Python
https://beartype.github.io/plum
MIT License

JIT + dispatch #154

Open nstarman opened 1 month ago

nstarman commented 1 month ago

Starting from the discussion in #152.

These are my quick thoughts on a suggestion for the docs and a potential improvement to dispatch.

import jax
from jaxtyping import Int, Array, Float, Float64
from plum import dispatch

First, let's look at the speed-up:

@dispatch
def nojit_func(a: int | Int[Array, ""]) -> int | Int[Array, ""]:
    return 2 * a

# %timeit nojit_func(3)
# 8.91 µs ± 33 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

@jax.jit
@dispatch
def jit_func(a: int | Int[Array, ""]) -> int | Int[Array, ""]:
    return 2 * a

# jit_func(3)  # trigger compilation
# %timeit jit_func(3)
# 5.03 µs ± 19.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Noticeable even for this trivial operation. 🎉

Now let's consider the problem of jitting a dispatched function:

@jax.jit
@dispatch
def func(a: int | Int[Array, ""]) -> int | Int[Array, ""]:
    return 2 * a

func(3) 
# Array(6, dtype=int64, weak_type=True)

try:
    out = func(3.0)
except Exception as e:
    print(e)
else:
    print(repr(out))
# `func(Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)` could not be resolved.

@jax.jit
@dispatch
def func(a: float | Float[Array, ""]) -> float | Float[Array, ""]:
    return 2.1 * a

func(3.0)
# Array(6.3, dtype=float64, weak_type=True)

@jax.jit
@dispatch(precedence=1)
def func(a: Float64[Array, ""]) -> Float64[Array, ""]:
    return 2.2 * a

func(3.0)
# Array(6.3, dtype=float64, weak_type=True)  # THIS IS WRONG!

func.clear_cache()

func(3.0)
# Array(6.6, dtype=float64, weak_type=True)   # THIS IS RIGHT!

We need to clear the cache for this to work correctly. When @dispatch is called, it could check whether there are any other dispatches for the function and, if so, try clearing the cache. Clearing the cache wholesale is not ideal, but until there's a targeted way to clear only the conflicting dispatches, it's necessary. A JIT-backend-agnostic way of doing this would be great, since JAX is not a dependency.
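To make the failure mode and the proposed fix concrete, here is a minimal, framework-free sketch. `ToyFunction` and `ToyJit` are hypothetical stand-ins, not Plum's or JAX's implementation: `ToyJit` runs dispatch only the first time it sees an argument type (like a trace) and then reuses the cached decision, while `ToyFunction.register` clears every attached JIT cache wholesale when a method is added to a function that already has methods. With `auto_clear=False` it reproduces the stale result from the example above.

```python
from typing import Any, Callable, Dict, List


class ToyFunction:
    """Toy dispatcher (not Plum): picks the method registered for the most
    specific class in the argument's MRO."""

    def __init__(self, auto_clear: bool = True) -> None:
        self.methods: Dict[type, Callable] = {}
        self.jitted: List["ToyJit"] = []  # JIT wrappers built around this function
        self.auto_clear = auto_clear

    def register(self, typ: type, method: Callable) -> None:
        had_methods = bool(self.methods)
        self.methods[typ] = method
        if self.auto_clear and had_methods:
            # A new method may change earlier dispatch decisions, so clear
            # every attached JIT cache wholesale.
            for wrapper in self.jitted:
                wrapper.clear_cache()

    def resolve(self, typ: type) -> Callable:
        for cls in typ.__mro__:  # most specific class first
            if cls in self.methods:
                return self.methods[cls]
        raise LookupError(f"no method for {typ.__name__}")


class ToyJit:
    """Stand-in for a tracing JIT: dispatch runs only when "tracing" a new
    argument type; afterwards the cached decision is reused."""

    def __init__(self, fn: ToyFunction) -> None:
        self.fn = fn
        self._cache: Dict[type, Callable] = {}
        fn.jitted.append(self)

    def __call__(self, x: Any) -> Any:
        typ = type(x)
        if typ not in self._cache:  # "trace" once and bake in the decision
            self._cache[typ] = self.fn.resolve(typ)
        return self._cache[typ](x)

    def clear_cache(self) -> None:
        self._cache.clear()


f = ToyFunction()
f.register(object, lambda x: 2 * x)
jf = ToyJit(f)
print(jf(3))  # 6: the object method is baked into the cached "trace"
f.register(int, lambda x: 3 * x)
print(jf(3))  # 9: registering the int method cleared the stale entry

g = ToyFunction(auto_clear=False)
g.register(object, lambda x: 2 * x)
jg = ToyJit(g)
print(jg(3))  # 6
g.register(int, lambda x: 3 * x)
print(jg(3))  # still 6: stale, like the "THIS IS WRONG!" case
jg.clear_cache()
print(jg(3))  # 9
```

Clearing only the wrappers attached to the function being extended is already narrower than a global clear, but it still discards decisions for argument types the new method cannot affect.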

PhilipVinc commented 1 month ago

A way to know whether the function is being jitted is to check whether an array returned by JAX is a Tracer:

import jax
import jax.numpy as jnp

def currently_jitting():
    return isinstance(jnp.array(1) + 1, jax.core.Tracer)

A hacky way to avoid depending on (or loading) JAX unless it is needed is to check whether jax is in sys.modules:

In [1]: import sys
In [2]: "jax" in sys.modules
Out[2]: False

In [3]: import jax
In [4]: "jax" in sys.modules
Out[4]: True

So if you verify that JAX is there, it has already been loaded, and you can safely import it as well...
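Combining the two suggestions above gives a check that never imports JAX itself. This is a minimal sketch that reuses the Tracer test from the snippet above unchanged; it assumes that, inside a `jax.jit` trace, arithmetic on constants yields a `jax.core.Tracer`, as that snippet relies on.

```python
import sys


def currently_jitting() -> bool:
    """Return True if called while JAX is tracing (e.g. under jax.jit).

    If JAX has never been imported, nothing can be jitting, so return
    False without importing (and paying the start-up cost of) JAX.
    """
    if "jax" not in sys.modules:
        return False
    import jax  # already in sys.modules, so this import is cheap
    import jax.numpy as jnp

    # Under a jit trace, even arithmetic on constants yields a Tracer.
    return isinstance(jnp.array(1) + 1, jax.core.Tracer)
```

Outside of any trace this returns False whether or not JAX is loaded, so it is safe to call from a dispatch hot path; the cost when JAX is loaded is one tiny array operation per call.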

Then, if you know whether the dispatch happened during JIT compilation, you can add a hook to Function.dispatch that records that the JAX cache must be cleared at the next new method registration.
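A minimal sketch of that deferred hook, using a hypothetical `TracingAwareFunction` instead of Plum's `Function`. The tracing check and the global clear are injected as callables so the logic can run without JAX; with JAX one might pass `currently_jitting` from the snippet above and, assuming a JAX version that provides it, `jax.clear_caches`. Dispatch is on the exact type of a single argument, for brevity.

```python
from typing import Any, Callable, Dict


class TracingAwareFunction:
    """Hypothetical sketch: defer a global JIT-cache clear until needed.

    A dispatch decision made while tracing may be baked into a compiled
    program, so the *next* method registration clears the global cache.
    Decisions made outside tracing never trigger a clear.
    """

    def __init__(
        self,
        is_tracing: Callable[[], bool],
        clear_global_cache: Callable[[], None],
    ) -> None:
        self.is_tracing = is_tracing
        self.clear_global_cache = clear_global_cache
        self.methods: Dict[type, Callable] = {}
        self._needs_clear = False

    def register(self, typ: type, method: Callable) -> None:
        if self._needs_clear:
            # A compiled program may hold a decision this method could change.
            self.clear_global_cache()
            self._needs_clear = False
        self.methods[typ] = method

    def __call__(self, x: Any) -> Any:
        method = self.methods[type(x)]  # exact-type lookup, for brevity
        if self.is_tracing():
            self._needs_clear = True  # this decision is now baked into a trace
        return method(x)
```

Because the flag is reset after each clear, a burst of registrations at import time costs at most one global clear, and programs that never dispatch under tracing never clear at all.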

However, the only way I can think of to make this work is to clear JAX's global cache, not the cache of a particular method. To clear the cache of a single method, we would need a way to identify which one is currently being jitted. I'm sure there are wacky ways to do it, but they might break easily...

nstarman commented 4 weeks ago

@wesselb, what do you think of adding a short example to the docs? As long as all dispatches are created/imported at the beginning, it works just fine. Even without solving the JIT cache issue, it would be good to have a warning banner explaining when things can go wrong.

wesselb commented 4 weeks ago

@nstarman Adding a short example to the docs that explains this, the massive performance benefits, and the caching problems would be super nice. :) I really like this.

If we don't want a dependency on JAX, then, to achieve comparable speed-ups, perhaps we could have a context manager that "remembers" all dispatch decisions:

from plum import Cache

with Cache() as cache:
    # Run a function with complicated dispatch. Internally, Plum saves the methods
    # that are chosen by the resolver in `cache`.
    f(arguments)

with cache:
    # Instead of running the resolver, choose the methods stored in `cache`. This
    # should eliminate all dispatch overhead.
    f(arguments)
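A minimal, self-contained sketch of how such a context manager could work. It uses a hypothetical `ToyFunction` with a simple first-match resolver rather than Plum's real resolver, and the `Cache` here is an illustration of the proposal, not an existing Plum class. The active cache is tracked with `contextvars`, and decisions are keyed on the function and the argument types.

```python
import contextvars
from typing import Any, Callable, Dict, List, Tuple

# The Cache that is currently active (None outside any `with` block).
_active_cache = contextvars.ContextVar("_active_cache", default=None)


class Cache:
    """Remembers dispatch decisions, keyed on (function, argument types)."""

    def __init__(self) -> None:
        self.decisions: Dict[Tuple[Any, Tuple[type, ...]], Callable] = {}
        self._tokens: List[contextvars.Token] = []

    def __enter__(self) -> "Cache":
        self._tokens.append(_active_cache.set(self))
        return self

    def __exit__(self, *exc_info: object) -> None:
        _active_cache.reset(self._tokens.pop())


class ToyFunction:
    """Toy multiple-dispatch function; the first matching signature wins."""

    def __init__(self) -> None:
        self.methods: List[Tuple[Tuple[type, ...], Callable]] = []
        self.resolver_calls = 0  # how often the resolver actually ran

    def register(self, *types: type) -> Callable[[Callable], Callable]:
        def decorator(method: Callable) -> Callable:
            self.methods.append((types, method))
            return method

        return decorator

    def _resolve(self, types: Tuple[type, ...]) -> Callable:
        self.resolver_calls += 1
        for signature, method in self.methods:
            if len(signature) == len(types) and all(
                issubclass(t, s) for t, s in zip(types, signature)
            ):
                return method
        raise LookupError(f"no method for {types}")

    def __call__(self, *args: Any) -> Any:
        types = tuple(type(a) for a in args)
        cache = _active_cache.get()
        if cache is None:  # no active Cache: always run the resolver
            return self._resolve(types)(*args)
        key = (self, types)
        if key not in cache.decisions:  # first time: record the decision
            cache.decisions[key] = self._resolve(types)
        return cache.decisions[key](*args)  # afterwards: replay it


f = ToyFunction()


@f.register(int)
def _(x: int) -> int:
    return 2 * x


@f.register(str)
def _(x: str) -> str:
    return x + x


with Cache() as cache:
    f(3)  # resolver runs and the decision for (int,) is recorded
    f(4)  # same argument types: replayed
with cache:
    f(5)  # replayed: no resolver call
print(f.resolver_calls)  # 1
```

Note that a recorded decision goes stale if a new method is registered after recording, which is the same invalidation problem as the JIT cache in this issue.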

How would something like that sound?

nstarman commented 3 weeks ago

For JAX, I was thinking of just identifying JAX "structurally": jax.jit produces a jaxlib object with a clear_cache attribute, which is sufficient to identify JAX-jitted functions. See https://github.com/beartype/plum/pull/158.
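A minimal sketch of that structural check, written as a hypothetical helper `clear_if_jitted` (not necessarily how the linked PR implements it). It never imports JAX: any object with a callable `clear_cache` attribute is treated as a JIT wrapper, which is the property of `jax.jit` outputs described above.

```python
from typing import Any


def clear_if_jitted(fn: Any) -> bool:
    """Clear `fn`'s compilation cache if it looks like a JIT wrapper.

    Duck-typed: anything exposing a callable `clear_cache` attribute (as
    the objects returned by `jax.jit` do) qualifies, so JAX is never
    imported. Returns True if a cache was cleared.
    """
    clear = getattr(fn, "clear_cache", None)
    if callable(clear):
        clear()
        return True
    return False
```

The check is permissive: any unrelated object that happens to define `clear_cache` would also match. Standard-library caches do not collide with it, because `functools.lru_cache` wrappers expose `cache_clear` instead.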