Open patrick-kidger opened 4 months ago
What was the error you saw? I see this
In [1]: import jax, jax.numpy as jnp, numpy as np, jax._src.test_util as jtu
In [2]: jax.jit(hash)(1)
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
Cell In[2], line 1
----> 1 jax.jit(hash)(1)
[... skipping hidden 15 frame]
File ~/venv/lib/python3.11/site-packages/jax/_src/dtypes.py:298, in _scalar_type_to_dtype(typ, value)
296 if typ is int and value is not None:
297 if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 298 raise OverflowError(f"Python int {value} too large to convert to {dtype}")
299 return dtype
OverflowError: Python int 8750429449870 too large to convert to int32
The error on 0.4.8 is
❯ python
Python 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.4.8'
>>> jax.jit(hash)(1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/pjit.py", line 238, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/pjit.py", line 180, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
^^^^^^^^^^^^^^^^
File ".../jax/_src/api.py", line 311, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/pjit.py", line 480, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
^^^^^^^^^^^^
File ".../jax/_src/pjit.py", line 918, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/linear_util.py", line 322, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File ".../jax/_src/pjit.py", line 874, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/interpreters/partial_eval.py", line 2049, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/interpreters/partial_eval.py", line 2066, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../jax/_src/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: TypeError: unhashable type: 'DynamicJaxprTracer'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'DynamicJaxprTracer'
On 0.4.29 I see no error at all. And I think that's the problem! To be clear, I think expected behaviour is for tracers not to be hashable.
I understand the error is different, but I do see an error on 0.4.29 (the one I posted above). Are you saying you see no error at all?
https://github.com/google/jax/pull/21826 should fix
Great! Thank you for the swift fix :)
(You can just __hash__ = None in the class scope though right?)
Sure, it's a matter of preference :)
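For readers comparing the two options mentioned here, the following is a minimal pure-Python sketch of how they differ. The class names and the try_hash helper are hypothetical and are not taken from JAX or from the linked PR. This thread does not say which approach the PR uses. Both options make hash() raise TypeError, but only __hash__ = None is visible to isinstance checks against collections.abc.Hashable.

```python
from collections.abc import Hashable


class UnhashableByNone:
    """Hypothetical class that opts out of hashing declaratively."""
    __hash__ = None


class UnhashableByRaise:
    """Hypothetical class that opts out by raising from __hash__."""
    def __hash__(self):
        raise TypeError(f"unhashable type: '{type(self).__name__}'")


def try_hash(obj):
    """Return the TypeError message from hash(obj), or None if hashing succeeds."""
    try:
        hash(obj)
    except TypeError as e:
        return str(e)
    return None


# Both variants fail when hash() is actually called.
none_msg = try_hash(UnhashableByNone())
raise_msg = try_hash(UnhashableByRaise())

# Only the declarative variant is reported as non-hashable by the ABC check,
# because Hashable looks for a __hash__ attribute that is not None.
none_is_hashable = isinstance(UnhashableByNone(), Hashable)
raise_is_hashable = isinstance(UnhashableByRaise(), Hashable)
```

One practical difference: a __hash__ method that raises can carry a custom error message, while __hash__ = None lets generic code detect unhashability up front without calling hash().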
I noticed this a couple of months ago and tried to fix it, but it turns out the hashability of tracers is used by a number of internal codepaths.
All tests are passing on my change.
I see some internal failures on your change.
I think it's the right change to make, but we need to fix those failures in order to land it.
My previous fix was #20903, but we couldn't land it because of those failures.
Jake was right. There are a bunch of failures internally :(
:(
Description
jax.jit(hash)(1) does not error on any recent version of JAX. I went back and checked, and the last time this seems to have produced an error was 0.4.8.
Is this intended? It means that tracers do not duck-type as either concrete JAX arrays or as numpy arrays. In particular see #21825, which this caused.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.29
jaxlib: 0.4.29
numpy: 1.26.4
python: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Air.localdomain', release='22.5.0', version='Darwin Kernel Version 22.5.0: Mon Apr 24 20:52:43 PDT 2023; root:xnu-8796.121.2~5/RELEASE_ARM64_T8112', machine='arm64')