google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

add `is_tracer`, `assert_concrete` functions to API #15625

Open JeppeKlitgaard opened 1 year ago

JeppeKlitgaard commented 1 year ago

I sometimes want jitting to fail early when a function that requires a concrete value is given a tracer instead, rather than producing a lengthy stack trace that obscures what went wrong.

For this, I am using the following functions, and thought it might be convenient and instructive to have them be part of the JAX API:

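from typing import Any

import jax
from jax.errors import ConcretizationTypeError

def is_tracer(obj: Any) -> bool:
    # True when `obj` is a JAX tracer (i.e. it is being traced by jit/grad/vmap).
    return isinstance(obj, jax.core.Tracer)

def assert_concrete(obj: Any, obj_name: str = ""):
    # Raise early, with a message naming the offending argument, if `obj` is traced.
    if is_tracer(obj):
        if not obj_name:
            obj_name = "<Unnamed>"

        _msg = f"Object `{obj_name}` needs to be concrete, but was a Tracer. "
        raise ConcretizationTypeError(obj, _msg)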

I then use these at the beginning of a function:

def test_function(array, concrete_num):
    assert_concrete(concrete_num, "concrete_num")
    slicer = min(5, concrete_num)
    return array[slicer:]

jakevdp commented 1 year ago

Thanks for raising this issue!

Ideas like this have come up before, and we've generally decided that it's not a great idea to provide public APIs for checking whether a value is traced. Why? In general, branching a function's logic on the output of something like is_tracer leads to functions that silently return unexpected results under autodiff and other JAX transformations. We'd rather not provide a public API that encourages users to write this kind of error-prone code.
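A minimal sketch (not from the thread) of the failure mode being described: `is_tracer` is the helper proposed above, and `scale` is a hypothetical function that takes a different code path when its input is traced. It assumes `jax.core.Tracer` is still importable in your JAX version.

```python
import jax
import jax.numpy as jnp


def is_tracer(obj) -> bool:
    # Helper proposed in the issue: True when `obj` is a JAX tracer.
    return isinstance(obj, jax.core.Tracer)


def scale(x):
    # Hypothetical misuse: branch on whether `x` is being traced.
    if is_tracer(x):
        return x * 2.0
    return x * 3.0


eager = scale(jnp.float32(1.0))            # untraced path: 3.0
jitted = jax.jit(scale)(jnp.float32(1.0))  # traced under jit: 2.0
slope = jax.grad(scale)(1.0)               # traced under grad: 2.0, not 3.0
```

None of the three calls raises an error; the function simply computes different things depending on how it is invoked, and the gradient no longer matches the function's eager behavior.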

That said, there is sometimes a need for the kind of assertion you mention, and for that purpose we have an existing utility that is used throughout JAX's own codebase:

import jax
from jax import core

def f(x):
  x = core.concrete_or_error(None, x, "The problem arose with the x argument.")
  return x

print(f(1))
# 1

jax.jit(f)(1)
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is
# expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
# The problem arose with the x argument.
# The error occurred while tracing the function f at <ipython-input-3-a66dd97d9283>:5 for jit.
# This concrete value was not available in Python because it depends on the value of the argument x.

Does this fit your needs?

JeppeKlitgaard commented 1 year ago

Thank you for a prompt and thought-out response as always!

I can certainly see why is_tracer could lead to branching logic, though I still think it could be useful when exposing a function and wanting to raise a custom error:

def test_function(array, concrete_num):
    if is_tracer(concrete_num):
        # Or more detailed descriptions here
        raise ValueError("Parameter `concrete_num` cannot be a tracer")

    slicer = min(5, concrete_num)
    return array[slicer:]

This is much nicer than wrapping core.concrete_or_error in a try ... except. Perhaps is_tracer could still be useful if it were paired with sufficient "sharp bits" warnings in the documentation?

Additionally, I was under the impression that all of core is considered internal-only? The namespace is not currently documented in the public API documentation.

If a public version of concrete_or_error were added, maybe its signature could be simplified? For the use case I had in mind, I don't think the force parameter would be ergonomic.

jakevdp commented 1 year ago

There are already a lot of sharp bits: I’d hesitate to add another without good reason. How would you anticipate an is_tracer function would be useful, aside from raising an error?

JeppeKlitgaard commented 1 year ago

I think adding detailed errors and catching bad input parameters early is the only reason I can come up with now.

I perhaps did myself a disservice by calling it a 'sharp bit'. I think this stands apart from other JAX sharp bits, which usually arise when a user tries to do something in JAX that would be allowed in 'normal' Python/NumPy code, for example accidentally passing a tracer where a concrete value is needed. is_tracer is different in that it is only ever used with the clear intent of "I want to do something specific when this is a tracer", so I would expect it not to be subject to much misuse. If it were implemented as part of the API, the documentation could warn that misusing is_tracer can lead to logic that branches depending on whether something is jitted or how it is called.

Separately from is_tracer, I think assert_concrete would be less controversial, as it can't easily lead to branching logic (though it's always possible to try ... except your way into trouble), and it would be a very convenient way to make functions "fail early and loudly". Used consistently, it could even dull some of JAX's other sharp bits a little!

soraros commented 1 year ago

@JeppeKlitgaard I'd vote against having an is_tracer in the public API. The name itself is already troublesome, as I have the impression that the JAX devs treat the whole Tracer naming as internal.

As for assert_concrete, I think it is indeed a useful tool. However, doesn't concrete_or_error already guarantee that the input value is concrete, and more? If you are talking about exposing a public version of it, then I agree (I have voiced something similar before), and maybe its type annotations could be improved as well. Otherwise, I don't see the point of providing something less powerful. BTW, looking at the implementation of concrete_or_error, your assert_concrete doesn't seem to handle ConcreteArray correctly.

JeppeKlitgaard commented 1 year ago

@soraros Thanks for your input!

Perhaps I am not understanding correctly, but I was under the impression that the Tracer/tracing terminology is something the end user of JAX would be expected to understand. It is, for example, mentioned here: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations

I think you're right that my version of assert_concrete does not catch ConcreteArray – I will take this as a good argument to have something like assert_concrete (or concrete_or_error) be part of a public API!

Most importantly, I think JAX should expose and document some way for end users to catch the scenario where a traced value is passed where a concrete value is needed. Jitting would fail anyway, but the stack trace doesn't make it immediately obvious what went wrong, so being able to fail early and loudly would be helpful. This could be done through a decorator taking argnums/argnames, an assert_concrete call at the beginning of a function, or an assertion like:

assert not is_tracer(concrete_arg), "concrete_arg needs to be concrete"

Having is_tracer (perhaps even better: is_concrete) would allow for the above assert statement, which I think is the most ergonomic and economic way of doing this, though more elaborate errors might be desirable when exposing a function through a library. I think the assert-approach would be quite nice in a notebook for example.

PhilipVinc commented 1 year ago

As an end user, I agree that you want to steer users towards writing correct code, and those utilities usually do more damage than good.

However, as a library author (of a library built on top of JAX), I know that concrete_or_error is useful for throwing the right error to users. For example, I have some Numba code that I can't convert to JAX, and users need to pass this function concrete values. It's supposed to always be called outside of jit boundaries, but users inevitably get this wrong. At the same time, I don't want to throw an error for a concrete JAX array, because it would be annoying for users to call np.asarray on each of their inputs.

Assuming you provide concrete_or_error, I think you could also provide is_concrete, which would also be useful, for example, to branch between two possibly different versions of logic that optimise or not depending on whether the value is concrete.

Moreover, I know that PennyLane tries to support both JAX and PyTorch, and ends up having to check whether an input is a tracer in order to support both backends correctly. That's surely a non-standard usage, but interfacing JAX with quantum computers is non-standard, so 🤷.

IMO, a good compromise would be a namespace (e.g. jax.advanced) that is still covered by API stability guarantees, so we can build libraries on top of it without fearing every new minor JAX release, but that is clearly marked to end users as "you most likely don't want to touch this".

JeppeKlitgaard commented 1 year ago

I quite like the idea of jax.advanced (though one might argue that because of JAX sharp bits, most of JAX might fit into such a namespace!).

I do think the use case of assert is_concrete(arg) would be fairly common, though. When adding new parameters to functions that are wrapped in transformations, one can quite easily lose track of which arguments are static and end up with a long stack trace and a trip to https://jax.readthedocs.io/en/latest/errors.html before realising a silly mistake.

jakevdp commented 1 year ago

I think you could also provide is_concrete, which would also be useful for example to branch between two possibly different versions of logic that optimise or not depending on the value being concrete or not.

This is exactly the kind of thing we don't want users to do, because it breaks the assumptions that JAX transformations are built on, and would make it very easy to write code that silently returns the wrong results. Anyone who's sophisticated enough to use this kind of logic correctly is sophisticated enough to figure out how to check if a value is traced if they need to.

jakevdp commented 1 year ago

@PhilipVinc for your use case where values must be concrete in order to be passed to numba, I'd suggest this:

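import numpy as np

def func_whose_values_must_be_concrete(x):
  try:
    # Concrete inputs (lists, NumPy or JAX arrays) convert; tracers raise a TypeError.
    x = np.asarray(x)
  except TypeError as err:
    raise TypeError("input must be concrete") from err
  # x is now a NumPy array that can be handed to the numba code.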

In the end, what you need is a value that can be converted to a NumPy array in order to pass it to numba. Trying the conversion and catching the error is arguably the most "pythonic" way to do this, and doesn't require any special JAX APIs.

Alternatively, you could use jax.pure_callback to call back to numba in a way that's maximally compatible with JAX transformations.
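A minimal sketch of the pure_callback route, assuming a plain NumPy function (`numpy_only_kernel`, hypothetical) stands in for the numba-compiled code. The caller declares the output shape and dtype so JAX can trace through the call without running it. Note that a pure callback is not differentiable on its own; gradients would need something like jax.custom_jvp on top.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_only_kernel(x):
    # Stand-in for a numba-compiled function: it only works on NumPy arrays.
    x = np.asarray(x)
    return np.cumsum(x, dtype=x.dtype)


def cumsum_via_callback(x):
    # Declare the result's shape/dtype so JAX can trace the call abstractly.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numpy_only_kernel, out_spec, x)


x = jnp.arange(4, dtype=jnp.float32)
y = jax.jit(cumsum_via_callback)(x)
```

Here the traced `x` never reaches NumPy directly; JAX passes concrete values to the callback at run time, so users don't have to keep the call outside jit boundaries.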

soraros commented 1 year ago

@JeppeKlitgaard I think trace/tracing/traced are fairly standard terminology, and since JAX uses a tracing JIT, it's only natural to use them. On the other hand, the fact that traced values are wrapped in Tracer objects is more of an implementation detail. (I might be wrong about this, though.)

IMO, compared with concrete_or_error, is_concrete isn't useful enough. Consider the following code:

c = concrete_or_error(f, a, "some context")

It says that we can convert, and have actually converted, a into a concrete value c of some concrete type with the help of the function f. Notably, a ConcreteArray can be considered static even when f is None. Also notice that is_concrete and is_tracer are not logical negations of each other, which is another argument against is_tracer. A bare is_concrete would leave a untouched, so we couldn't use its concrete value. If we were to add this new function (which I don't think we should), a more accurate name might be is_concretizable.
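To illustrate the "convert and return" point, here is a minimal sketch of a hypothetical helper, `concretize`, modeled loosely on the `concrete_or_error(f, a, context)` call shape above. It is not JAX's concrete_or_error and does not reproduce its behavior; it assumes that conversion via np.asarray is an acceptable concreteness test, and it hands back the converted value rather than a bare boolean.

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np


def concretize(fun: Optional[Callable[[Any], Any]], value: Any, context: str = "") -> Any:
    """Hypothetical helper: return a concrete version of `value`, or raise."""
    try:
        # A traced value cannot be converted to NumPy and raises a TypeError subclass.
        concrete = np.asarray(value)
    except TypeError as err:
        raise TypeError(f"Expected a concrete value. {context}") from err
    # Optionally apply `fun` (e.g. int) so the caller gets a usable Python value.
    return concrete if fun is None else fun(concrete)


def take_tail(array, n):
    n = concretize(int, n, "The problem arose with the n argument.")
    return array[min(5, n):]
```

Because the helper returns `n` as a plain int, the rest of `take_tail` can use it directly; under `jax.jit(take_tail)` the call fails immediately with the supplied context.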

... When adding new parameters to functions that are wrapped in transformations, one can quite easily lose track of which arguments are static and end up with a long stack trace and a trip to jax.readthedocs.io/en/latest/errors.html before realising a silly mistake

Agreed, and this exact use case is covered by core.concrete_or_error, with many examples to be found in jax/_src/numpy/lax_numpy.py.

JeppeKlitgaard commented 1 year ago

@soraros I think I confusingly presented is_tracer and assert_concrete in a way that suggested I was proposing these exact implementations; that was not my intention. I agree that core.concrete_or_error is probably what I am looking for, though I am advocating for it (or some variation of it) to be made part of the public API.

concrete_or_error raises an error, which is convenient in many cases. Where a more customized error message is needed (beyond what can be put in the context argument of concrete_or_error), I could see is_concretizable being helpful.

I think exposing concrete_or_error would be a good way to emphasise to users that they need to be cautious about traced values. This becomes second nature after working with JAX for a while, but I've found that it is one of the concepts that trips up new users. Having these assertions at the top of a function also makes it immediately obvious to a reader that the function was written with certain constraints in mind.