google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Proposal] Define static arguments using typing annotations #10476

Open JeppeKlitgaard opened 2 years ago

JeppeKlitgaard commented 2 years ago

I am really enjoying jax, but have found the ergonomics of declaring function arguments as static for use in jit, pjit, and pmap to be a little cumbersome and error-prone when refactoring.

In many (but not necessarily all) cases it is useful and intuitive to declare some arguments as static at the point of function declaration rather than later when mapping or jitting the function. This contrasts with axis-mapping, which is always most intuitively done at the point of transforming the function using xmap, vmap, or pmap.

Using PEP 593 (Python 3.9+, backported in typing_extensions) it should be possible to declare static arguments in the function signature.

Example

We have some function that has both dynamic and static arguments. Here c_func is "embarrassingly static" in that we strictly require it to be static at compile-time.

from typing import Callable

import jax.numpy as jnp

def make_hypercube(a, b, c: float, c_func: Callable[[float], float]):
    arr = jnp.empty((3, 3))

    # Some other implementation details here
    arr += a
    arr *= b

    arr **= c_func(c)
    ...

    return arr

Currently to pmap over a and b, one would do:

pmake_hypercube = pmap(make_hypercube, in_axes=(0, 0, None, None), static_broadcasted_argnums=(3,))

This is not ideal in that changing the signature of make_hypercube would require us to change our transformation, bearing in mind that these may be in very different parts of the code. Additionally, when functions have many arguments, changing the argnums tuple during refactoring or experimentation becomes needlessly cumbersome and error-prone.

I would propose adding a PEP 593 annotation, which would complement (not replace) the current method of marking arguments as static. In this case our implementation becomes:

from jax import pmap, Static

def make_hypercube(a, b, c: float, c_func: Static[Callable[[float], float]]):
    arr = jnp.empty((3, 3))

    # Some other implementation details here
    arr += a
    arr *= b

    arr **= c_func(c)
    ...

    return arr

pmake_hypercube = pmap(make_hypercube, in_axes=(0, 0, None, None))

Implementation

While the PEP and Python documentation go into more details on this, such an annotation would be implemented as:

from typing import Annotated, TypeVar, get_type_hints, get_args

T = TypeVar("T")

# Sentinel stored in the Annotated metadata to mark an argument as static
StaticAnnotation = object()
Static = Annotated[T, StaticAnnotation]

def f(a: str, b: Static[dict[str, float]]):
    ...

def is_static(func, arg) -> bool:
    # include_extras=True keeps the Annotated metadata instead of stripping it
    return StaticAnnotation in get_args(get_type_hints(func, include_extras=True)[arg])

print(is_static(f, "a"))  # False
print(is_static(f, "b"))  # True
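
To make the idea above concrete, here is a minimal sketch of how a transformation could collect every annotated argument in one pass. The names `_StaticMarker` and `static_argnames` are hypothetical and not part of JAX; the only JAX-facing assumption is that the resulting tuple would be handed to the existing `static_argnames` parameter of `jax.jit`.

```python
import inspect
from typing import Annotated, Callable, TypeVar, get_origin, get_type_hints

T = TypeVar("T")


class _StaticMarker:
    """Hypothetical sentinel placed in the Annotated metadata (not a JAX API)."""


Static = Annotated[T, _StaticMarker]


def static_argnames(func) -> tuple[str, ...]:
    """Return the names of all parameters annotated with Static[...]."""
    hints = get_type_hints(func, include_extras=True)
    names = []
    for name in inspect.signature(func).parameters:
        hint = hints.get(name)
        # Only Annotated hints carry a __metadata__ tuple.
        if get_origin(hint) is Annotated and _StaticMarker in hint.__metadata__:
            names.append(name)
    return tuple(names)


def make_hypercube(a, b, c: float, c_func: Static[Callable[[float], float]]):
    ...


print(static_argnames(make_hypercube))  # ('c_func',)

# Assumed usage with JAX (not executed here):
# jitted = jax.jit(make_hypercube, static_argnames=static_argnames(make_hypercube))
```

Because the names are read from the signature, renaming or reordering parameters would not require touching the call to the transformation.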

Other use-cases

A similar situation might apply to non-differentiable arguments: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#handling-non-differentiable-arguments

Which might have an annotation NoDiff.

TL;DR

Introduce Static annotation to specify static arguments in function signatures. Should work with jit, pjit, pmap.

JeppeKlitgaard commented 2 years ago

I would be happy to do the implementation and documentation for this if I get a few pointers in the right direction. typing_extensions is already a dependency, so there shouldn't be any additional dependencies needed.

Other libraries have seen great success with adopting the new type hinting language features (pydantic, fastapi, ...). It should be noted that the Annotated 'pass-through'-style type-hints haven't yet seen much adoption, but I think for cases like this one they provide the most Pythonic experience.

If the user does not wish to use type-hinting, but still use the Static annotation, one can simply bind the generic to Any with:

def func(a, b: Static[Any]):
    ...

# Or to reduce verbosity
S = Static[Any]

def func(a, b: S):
    ...

patrick-kidger commented 2 years ago

It's not exactly the same, but the Equinox project offers something very similar to what you're requesting.

Equinox offers a few "filtered transformations" that take a rule for determining how the arguments should be partitioned, and then uses that instead of the manually-specified approach used at the moment. (Equinox also implements some neural networks etc, but that's an unrelated part of the library.)

The main difference to your approach is that the partitioning is handled using the types of the arguments actually provided, rather than the types provided via the annotation. (And if you really want I think it should probably be possible to pass a filter_spec that does the annotation-like behaviour, so you can have that too.)

filter_{jit,grad,value_and_grad,custom_vjp} currently exist; the next release will also include filter_{vmap,pmap}.


Side notes.

  1. Equinox's approach is actually more powerful: it allows for splitting each individual argument into dynamic/static/etc. pieces, rather than each entire argument having to be the same choice of dynamic/static/etc. This within-argument splitting comes up quite naturally in quite a lot of contexts, e.g. we might store the data for an MLP as [[weight0, bias0], activation, [weight1, bias1]], in which the weights and biases are JAX arrays but the activation is some Python callable. It would then be neatest to pass all of this data as a single argument.

  2. Your example c_func need not be "embarrassingly static". It's perfectly possible to have callable dynamic objects. Going back to Equinox: try passing c_func as an instance of equinox.nn.Linear. This works because custom PyTrees can (a) contain only dynamic JAX arrays and (b) implement a __call__ method.
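
The within-argument splitting described in the first side note can be illustrated with a pure-Python sketch. This is not Equinox's implementation: `partition`, `combine`, and `is_array` here are simplified stand-ins written for this example, nested lists stand in for PyTrees, and the standard-library `array.array` stands in for a JAX array.

```python
from array import array


def is_array(leaf) -> bool:
    """Stand-in predicate: treat array.array leaves as 'dynamic'."""
    return isinstance(leaf, array)


def partition(tree, predicate):
    """Split a nested list into (dynamic, static) trees of the same shape."""
    if isinstance(tree, list):
        pairs = [partition(subtree, predicate) for subtree in tree]
        return [p[0] for p in pairs], [p[1] for p in pairs]
    # A leaf goes to exactly one side; the other side gets a None placeholder.
    return (tree, None) if predicate(tree) else (None, tree)


def combine(dynamic, static):
    """Inverse of partition: merge the two trees back together."""
    if isinstance(dynamic, list):
        return [combine(d, s) for d, s in zip(dynamic, static)]
    return dynamic if dynamic is not None else static


def relu(x):
    return max(x, 0.0)


# [[weight0, bias0], activation, [weight1, bias1]] passed as a single argument
mlp = [
    [array("d", [1.0, 2.0]), array("d", [0.5])],
    relu,
    [array("d", [3.0]), array("d", [0.1])],
]

dynamic, static = partition(mlp, is_array)
restored = combine(dynamic, static)
```

Here `dynamic` keeps the weights and biases with `None` in place of the activation, while `static` keeps only `relu`, so a single argument is split into traced and static pieces.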

JeppeKlitgaard commented 2 years ago

I really like the look of Equinox, which seems to have a lot of improvements to the ergonomics of JAX – eqx.Module seems similar in spirit to the idea of a dataclass that is also automatically a PyTree, which seems to be a highly requested feature (cf.: #9662, #1808, #2371)

Some of these discussions predate wider adoption of PEP 593.

I think my attempt at trying to coin a phrase might need some elaboration – "embarrassingly static" was intended to refer to arguments that are always statically compiled, such as a Python function or any other object that cannot be replaced by a Tracer and jitted. Since these arguments are necessarily static in all cases, they can be deduced as static when defining the function.

An important exception to this could be a function that takes an argument, arr, that can be either a numpy or JAX array. In this case we might want to specify that arr is static at the point of calling jax.jit in the case where arr is a numpy array.

The problem

In native JAX we currently need to carry around a 'schema' that we can use to define which arguments are static. This schema is static_argnums in the case of jax.jit and static_broadcasted_argnums in the case of jax.pmap. It is not very ergonomic to type out and is prone to breakage when refactoring.

The solution by Equinox

(as I understand it, please correct)

Using Equinox this situation is improved by either:

Filter spec

Making a custom filter_spec and using filtered transformations filter_jit, filter_pmap.

This solution still requires creating a schema-like object and passing it to the filtered transformation. It is arguably more ergonomic.

Using is_array filter spec

The default filter spec turns all non-arrays into static arguments. This is essentially exactly what I wanted, but I don't really like the fact that it 'hides' from the user which arguments are statically compiled and which are traced. Since it is done using a runtime instance check, it won't fail loudly in cases where an argument inferred to be static is passed in a place where the user thinks it is dynamic. This could silently lead to excessive re-compilation.
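
The recompilation concern can be illustrated with a toy model. The sketch below is not how JAX or Equinox work internally: `toy_jit` is a hypothetical stand-in that only models the fact that static values become part of the compilation cache key, so every new static value costs one "compilation".

```python
import functools


def toy_jit(func, static_argnames=()):
    """Toy stand-in for a JIT: 'compiles' once per distinct set of static values."""
    cache = {}

    def wrapper(**kwargs):
        # Static values form the cache key, so they must be hashable.
        key = tuple((name, kwargs[name]) for name in static_argnames)
        if key not in cache:
            wrapper.compilations += 1  # a real JIT would trace and compile here
            cache[key] = functools.partial(func, **dict(key))
        dynamic = {k: v for k, v in kwargs.items() if k not in static_argnames}
        return cache[key](**dynamic)

    wrapper.compilations = 0
    return wrapper


def scale(x, factor):
    return x * factor


static_call = toy_jit(scale, static_argnames=("factor",))
dynamic_call = toy_jit(scale)

for step in range(5):
    static_call(x=1.0, factor=float(step))   # new static value on every call
    dynamic_call(x=1.0, factor=float(step))  # same cache entry on every call

print(static_call.compilations)   # 5
print(dynamic_call.compilations)  # 1
```

If `factor` is silently inferred to be static while the caller assumes it is traced, the results are still correct, which is why the extra compilations are easy to miss.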

The solution using annotations

YouJiacheng commented 2 years ago

Sounds Great!❤️

chrisflesher commented 1 year ago

Can you use static_argnames instead of static_argnums to help with refactoring?
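
For context, `jax.jit` accepts `static_argnames` directly. For a call that takes positional indices, such as the `pmap` example with `static_broadcasted_argnums` earlier in the thread, names could be translated into positions with a small helper. The function `argnums_from_names` below is hypothetical and not part of JAX; it is a sketch using only the standard library.

```python
import inspect


def argnums_from_names(func, names):
    """Translate parameter names into positional indices (hypothetical helper)."""
    params = list(inspect.signature(func).parameters)
    unknown = sorted(set(names) - set(params))
    if unknown:
        # Fail loudly if a refactor renamed or removed a parameter.
        raise ValueError(f"{func.__name__} has no parameter(s) named {unknown}")
    return tuple(index for index, name in enumerate(params) if name in names)


def make_hypercube(a, b, c, c_func):
    ...


print(argnums_from_names(make_hypercube, ("c_func",)))  # (3,)

# Assumed usage with JAX (not executed here):
# pmap(make_hypercube, in_axes=(0, 0, None, None),
#      static_broadcasted_argnums=argnums_from_names(make_hypercube, ("c_func",)))
```

Reordering the parameters of `make_hypercube` then changes the computed indices automatically, and a renamed parameter raises an error instead of silently marking the wrong argument.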

I've been using Equinox and really like it so far. It has some decent default settings for filter_jit and filter_grad, so 90% of the time I don't have to worry about the filter_spec stuff.

patrick-kidger commented 1 year ago

Folks here may be interested in some changes we're proposing for Equinox. Namely, to simplify the filter_jit API to just filter_jit(fn, donate), with all JAX/NumPy arrays dynamically traced and everything else held static. (And the donate flag handling donation behaviour, of course.)

This is in recognition of the fact that all arrays have to be treated dynamically, and most Python types have to be treated statically, so really the only room for non-default behaviour (i.e. treating bool/int/float/complex dynamically) can be handled just by wrapping these types with jnp.asarray yourself.

The PR for this is here: https://github.com/patrick-kidger/equinox/pull/235
The code is here: https://github.com/patrick-kidger/equinox/blob/ee7f91a1c422abc0e502c6dd467045e79c1717bc/equinox/jit.py#L87

I'd welcome any feedback on this change.

chrisflesher commented 1 year ago

@patrick-kidger The proposed changes seem like they would make things easier to use... I was interested in how this would handle typing.NamedTuple. I'm guessing those would be treated as static by default?

Would the filter_spec stuff still be required for making equinox.Module array attributes static? I have an equinox.Module that has a default_value attribute that I need to always keep static despite being a jax.Array type. Currently I'm using the filter_spec functionality to do this. With the new API changes would this functionality remain the same?

patrick-kidger commented 1 year ago

Regarding NamedTuple: as this is a pytree node, it will be unpacked, and each of its elements treated individually.

Regarding static Arrays: this should be impossible right now. Static inputs must be hashable, and jax.Arrays aren't hashable. So I'm not sure what you're currently doing here.

chrisflesher commented 1 year ago

Here are some more details. I have a voxel map class.

class VoxelMap(equinox.Module):
    """A sparse voxel map.

    Attributes:
        resolution: Voxel width [meters]
        indices: Dense 3D array relating voxel index to values. Populated voxels are set to
            integers in the range [0, num_populated_voxels), unpopulated are set to -1
        values: Dense 2D array of voxel values (num_populated_voxels, num_elements_per_voxel)
        default_value: Default voxel value (num_elements_per_voxel)
    """

    resolution: float
    indices: jax.Array
    values: jax.Array
    default_value: jax.Array

The equinox.filter_grad function updates everything properly by default except for the default_value attribute. I can't just set its type to float because it needs to handle more than one element sometimes. Currently I'm using filter_spec functionality to ignore this attribute, but it is a bit complicated to read. I wouldn't mind baking it into the class somehow. Are there other options to consider? Maybe baking it into the class using a decorator?

patrick-kidger commented 1 year ago

Ah, we're (partly) talking past each other.

The proposed change is to filter_jit. The function you're using is filter_grad. Right now that's staying the same.

However -- you're right that the filtering argument to filter_grad is ugly! So first of all, it is actually possible to avoid using it, and bake this into the class:

class Buffer(eqx.Module):
    array: jax.Array

    def get(self):
        return lax.stop_gradient(self.array)

class VoxelMap(eqx.Module):
    ...
    default_value: Buffer

    def __init__(self, ..., default_value):
        self.default_value = Buffer(default_value)

    def __call__(self, ...):
        default_value = self.default_value.get()
        ...

As it happens, we are also contemplating removing the argument of equinox.filter_grad. (We just haven't done it yet.) Rationale:

If you are one of the relatively few users using the filtering argument of eqx.filter_grad, then I'd welcome feedback on that proposed change as well!

chrisflesher commented 1 year ago

Ah, thanks, lax.stop_gradient seems like the best solution.

JesseFarebro commented 1 year ago

At one point I was also thinking about this feature, but after more discussion internally it seems like an implementation based on annotations would be error-prone and hard to put guard rails around.

The issue is that you could have other decorators that are misbehaved or perform arbitrary transformations on the arguments that make it so Jax can't inspect the annotations.

For well-behaved decorators that use functools.wraps and typing.ParamSpec things would work out nicely, but in general there could be lots of silent failure cases and there's no way to detect a misbehaved decorator.
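
The silent failure mode can be reproduced with the standard library alone. In this sketch the marker `STATIC` and the helper `static_names` are hypothetical stand-ins for whatever JAX would use; the point is only to show what annotation inspection sees after each kind of decorator.

```python
import functools
from typing import Annotated, get_type_hints

STATIC = "static"  # hypothetical marker stored in the Annotated metadata


def static_names(func):
    """List the parameters whose annotation carries the STATIC marker."""
    hints = get_type_hints(func, include_extras=True)
    return [
        name
        for name, hint in hints.items()
        if STATIC in getattr(hint, "__metadata__", ())
    ]


def well_behaved(func):
    @functools.wraps(func)  # copies the annotations onto the wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def misbehaved(func):
    def wrapper(*args, **kwargs):  # annotations of func are not visible here
        return func(*args, **kwargs)
    return wrapper


def f(a: float, n: Annotated[int, STATIC]):
    return a * n


print(static_names(well_behaved(f)))  # ['n']
print(static_names(misbehaved(f)))    # []
```

The second call raises no error and simply reports no static arguments, which is the undetectable case described above.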

JeppeKlitgaard commented 1 year ago

@JesseFarebro I admittedly haven't messed around enough with PEP 593 to have a feel for how often decorators mangle the annotations, but I understand if Jax doesn't want to introduce what would admittedly be hard-to-debug cases by introducing this syntax. I can't gauge whether a warning in the documentation might be deemed sufficient to go ahead with a proposal like this anyway – the use-case matches the one mentioned in the PEP very well, so presumably there is some enthusiasm for syntax like this being supported in the Python community.

A large amount of the 'frustration' or fiddliness around annotating arguments in Jax currently could also be alleviated by #10614, which is independent of this proposal.