jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Tracers are now hashable. #21824

Open patrick-kidger opened 4 months ago

patrick-kidger commented 4 months ago

Description

jax.jit(hash)(1) does not error on any recent version of JAX. I went back and checked, and the last version on which this seems to have produced an error was 0.4.8.

Is this intended? It means that tracers do not duck-type as either concrete JAX arrays or as numpy arrays. In particular see #21825, which this caused.
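Below is a minimal pure-Python sketch of the duck-typing concern. The classes `ArrayLike` and `HashableTracerLike` and the helper `looks_like_array` are hypothetical stand-ins, not JAX or NumPy types. The sketch assumes, as the issue describes, that real arrays are unhashable while the tracer inherits an ordinary hash. It is not meant to show how #21825 actually arose.

```python
from collections.abc import Hashable


class ArrayLike:
    """Hypothetical stand-in for an unhashable array type (like numpy.ndarray)."""

    __hash__ = None  # explicitly opt out of hashing


class HashableTracerLike:
    """Hypothetical stand-in for a tracer that inherits object.__hash__."""


def looks_like_array(x):
    # Hypothetical duck-typing check: treat anything unhashable as array-like
    # and anything hashable as a static Python value.
    return not isinstance(x, Hashable)


print(looks_like_array(ArrayLike()))           # True
print(looks_like_array(HashableTracerLike()))  # False: misclassified as static
print(looks_like_array(3))                     # False
```

In this sketch a hashable tracer is silently routed down the "static value" branch, whereas an unhashable one would be classified together with real arrays.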

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.29
jaxlib: 0.4.29
numpy: 1.26.4
python: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Air.localdomain', release='22.5.0', version='Darwin Kernel Version 22.5.0: Mon Apr 24 20:52:43 PDT 2023; root:xnu-8796.121.2~5/RELEASE_ARM64_T8112', machine='arm64')

yashk2810 commented 4 months ago

What was the error you saw? I see this:

In [1]: import jax, jax.numpy as jnp, numpy as np, jax._src.test_util as jtu

In [2]: jax.jit(hash)(1)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
Cell In[2], line 1
----> 1 jax.jit(hash)(1)

    [... skipping hidden 15 frame]

File ~/venv/lib/python3.11/site-packages/jax/_src/dtypes.py:298, in _scalar_type_to_dtype(typ, value)
    296 if typ is int and value is not None:
    297   if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 298     raise OverflowError(f"Python int {value} too large to convert to {dtype}")
    299 return dtype

OverflowError: Python int 8750429449870 too large to convert to int32

patrick-kidger commented 4 months ago

The error on 0.4.8 is

❯ python                     
Python 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.4.8'
>>> jax.jit(hash)(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/pjit.py", line 238, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
                                          ^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/pjit.py", line 180, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
                                                 ^^^^^^^^^^^^^^^^
  File ".../jax/_src/api.py", line 311, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/pjit.py", line 480, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
                                                      ^^^^^^^^^^^^
  File ".../jax/_src/pjit.py", line 918, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
                                    ^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/linear_util.py", line 322, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File ".../jax/_src/pjit.py", line 874, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/interpreters/partial_eval.py", line 2049, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/interpreters/partial_eval.py", line 2066, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../jax/_src/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: TypeError: unhashable type: 'DynamicJaxprTracer'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'DynamicJaxprTracer'

On 0.4.29 I see no error at all, and I think that's the problem! To be clear, I think the expected behaviour is for tracers not to be hashable.

yashk2810 commented 4 months ago

I understand the error is different, but I do see an error on 0.4.29 (the one I posted above). Are you saying you see no error at all?
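The OverflowError traceback posted earlier suggests that on 0.4.29 `hash` itself succeeded: the failure comes afterwards, in `_scalar_type_to_dtype`, when the returned Python int is converted to an `int32` output (JAX's default integer width when 64-bit mode is off). The sketch below reproduces only that range check in plain Python; `fits_in_int32` is a hypothetical name. Under this reading, whether `jax.jit(hash)(1)` raises would depend on the magnitude of the hash value. CPython's default object hash is derived from the object's address, which can differ between machines, so this is one possible, unconfirmed explanation of why one platform errors and the other does not.

```python
# Bounds of a signed 32-bit integer, i.e. np.iinfo(np.int32).min / .max.
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fits_in_int32(value: int) -> bool:
    # Mirrors the check shown in the traceback:
    #   value < np.iinfo(dtype).min or value > np.iinfo(dtype).max -> OverflowError
    return not (value < INT32_MIN or value > INT32_MAX)


# The hash value reported in the OverflowError above.
reported_hash = 8750429449870
print(fits_in_int32(reported_hash))  # False, so the conversion raises
print(fits_in_int32(2**31 - 1))      # True
```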

yashk2810 commented 4 months ago

https://github.com/google/jax/pull/21826 should fix this.

patrick-kidger commented 4 months ago

Great! Thank you for the swift fix :) (You could just set __hash__ = None in the class scope, though, right?)

yashk2810 commented 4 months ago

Sure, it's a matter of preference :)
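In plain Python the two options behave the same for `hash()` itself: both raise `TypeError`. They differ in one respect, though. Setting `__hash__ = None` also makes `isinstance(x, collections.abc.Hashable)` return `False`, while a `__hash__` method that raises still passes that check. The classes below are hypothetical and only illustrate the Python semantics; they are not taken from the linked PR.

```python
from collections.abc import Hashable


class HashIsNone:
    # Opting out in the class scope, as suggested above.
    __hash__ = None


class HashRaises:
    # Opting out via a method that raises explicitly.
    def __hash__(self):
        raise TypeError(f"unhashable type: {type(self).__name__!r}")


def hash_error(obj):
    """Return the TypeError message from hash(obj), or None if hashing succeeds."""
    try:
        hash(obj)
    except TypeError as e:
        return str(e)
    return None


for cls in (HashIsNone, HashRaises):
    obj = cls()
    # Both raise TypeError; only HashRaises still passes the Hashable check.
    print(cls.__name__, hash_error(obj), isinstance(obj, Hashable))
```

This difference matters mainly for code that duck-types on `Hashable` rather than calling `hash()` directly.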

jakevdp commented 4 months ago

I noticed this a couple of months ago and tried to fix it, but it turns out that a number of internal codepaths rely on tracers being hashable.
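The thread does not say which internal codepaths depend on hashing tracers. As a generic, hypothetical illustration: any code that puts objects in a `dict` or `set` calls `__hash__`, so it breaks once those objects become unhashable. One common workaround is to key on object identity (`id(obj)`) instead. `Unhashable` and `IdentityDict` below are hypothetical helpers, not JAX APIs.

```python
class Unhashable:
    """Hypothetical stand-in for an object that opts out of hashing."""

    __hash__ = None


class IdentityDict:
    """Minimal hypothetical mapping keyed by object identity instead of hash."""

    def __init__(self):
        # id(key) -> (key, value); holding a reference to `key` prevents it from
        # being garbage-collected (and its id reused) while it is stored here.
        self._data = {}

    def __setitem__(self, key, value):
        self._data[id(key)] = (key, value)

    def __getitem__(self, key):
        return self._data[id(key)][1]

    def __contains__(self, key):
        return id(key) in self._data

    def __len__(self):
        return len(self._data)


x, y = Unhashable(), Unhashable()

try:
    {x: "value"}  # a regular dict needs hash(x), so this raises TypeError
except TypeError as e:
    print("dict failed:", e)

seen = IdentityDict()
seen[x] = "first"
print(x in seen, y in seen, seen[x])  # True False first
```

Identity keys avoid hashing entirely, but they also ignore equality: two distinct objects are always different keys.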

yashk2810 commented 4 months ago

All tests are passing on my change.

jakevdp commented 4 months ago

I see some internal failures on your change.

I think it's the right change to make, but we need to fix those failures in order to land it.

jakevdp commented 4 months ago

My previous fix was #20903, but we couldn't land it because of those failures.

yashk2810 commented 4 months ago

Jake was right. There are a bunch of failures internally :(

patrick-kidger commented 4 months ago

:(