jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Can not export function with input/output of typed keys #24143

Open jjyyxx opened 1 month ago

jjyyxx commented 1 month ago

Description

More precisely, a function whose inputs or outputs are typed keys can be exported, but the exported object cannot be serialized, because serialization.fbs has no entries for the key dtypes.

import jax, jax.export

@jax.jit
def f(key):
    return key

key = jax.random.key(0)

exported = jax.export.export(f)(key)  # Success

with open('exported.jax', 'wb') as g:
    g.write(exported.serialize())  # KeyError: key<fry>

As a temporary workaround, the key dtypes can be registered manually before serializing:

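# NOTE: this relies on private modules (jax._src), which may change between releases.
from jax._src.prng import prngs, KeyTy
from jax._src.export.serialization import _dtype_to_dtype_kind, _dtype_kind_to_dtype

# Give each registered PRNG implementation's key dtype an unused kind ID and
# add it to both the dtype-to-kind and the kind-to-dtype lookup tables.
last = max(_dtype_to_dtype_kind.values())
for prng in prngs.values():
    dtype = KeyTy(prng)
    last += 1
    _dtype_to_dtype_kind[dtype] = last
    _dtype_kind_to_dtype[last] = dtype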

If extensibility via jax.extend.random.define_prng_impl is not a priority, I guess it is OK to hard-code fry, rbg, and urbg in serialization.fbs, and I could open a PR for this.
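For reference, here is a small sketch that lists the key dtypes of the built-in PRNG implementations using only public APIs. It assumes the built-in implementations are selectable by the names "threefry2x32", "rbg", and "unsafe_rbg"; the `impl_names` and `key_dtypes` variables are just illustrative.

```python
import jax

# Assumption: these are the names of the built-in PRNG implementations.
impl_names = ["threefry2x32", "rbg", "unsafe_rbg"]

key_dtypes = {}
for name in impl_names:
    key = jax.random.key(0, impl=name)
    key_dtypes[name] = key.dtype
    # Typed keys carry an extended dtype that is a subtype of prng_key.
    is_key = jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
    # The underlying representation is a plain uint32 array.
    raw_shape = jax.random.key_data(key).shape
    print(name, key.dtype, is_key, raw_shape)
```

Each of these dtypes would need its own entry in the serialization tables if they were hard-coded.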

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.11.10 | packaged by conda-forge | (main, Sep 22 2024, 14:10:38) [GCC 13.3.0]
jax.devices (2 total, 2 local): [CudaDevice(id=0) CudaDevice(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='x', release='6.8.0-45-generic', version='#45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024', machine='x86_64')

gnecula commented 1 month ago

@froystig Any suggestion for how to handle this? We would have to promise backwards compatibility of the serialized artifacts.

jjyyxx commented 1 month ago

@gnecula Besides, may I ask about the current status of user-defined custom dtypes in JAX (not via __jax_array__)? Is there a public API for them?

froystig commented 1 month ago

> [...] may I ask about the current status of user-defined custom dtypes in JAX (not via __jax_array__)? Is there a public API for them?

Not currently. What do you have in mind?

froystig commented 1 month ago

> @froystig Any suggestion for how to handle this?

As mentioned in the original post, if we only care to support built-in PRNG implementations (which I think we have a case for), then how about we indeed add them as de/serializable dtypes?

https://github.com/jax-ml/jax/blob/3fc4ba29ea2d5498f90c6ef8fda6edb49c1d835c/jax/_src/export/serialization.fbs#L44-L72

I assume a definition here isn't enough and that we also need to add logic to unwrap and wrap the keys somewhere.

> We would have to promise backwards compatibility of the serialized artifacts.

What would that mean in this case? I would expect that serializing with typed keys never worked, considering this issue. In that case, no artifacts out there have them, and so their deserialization would not change. Does that sound correct?

jjyyxx commented 1 month ago

> > [...] may I ask about the current status of user-defined custom dtypes in JAX (not via __jax_array__)? Is there a public API for them?
>
> Not currently. What do you have in mind?

Not really. I was thinking that custom dtypes are similar to typed keys. If they are not publicly supported, the discussion can be restricted to typed keys only.

> I assume a definition here isn't enough and that we also need to add logic to unwrap and wrap the keys somewhere.

The workaround I posted is enough to make export with typed keys work, but I was unaware of the compatibility issues at the time. If the implementation instead unwraps and wraps the keys, I guess it may be unnecessary to modify serialization.fbs?
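To illustrate the unwrap/wrap idea from the user side, here is a minimal sketch that keeps typed keys out of the exported signature by passing raw key data across the boundary. It uses the public jax.random.key_data and jax.random.wrap_key_data functions and assumes the "threefry2x32" implementation. The function name `f_raw` is hypothetical, and this is not necessarily how JAX itself would implement the fix.

```python
import jax
import jax.numpy as jnp
from jax import export

IMPL = "threefry2x32"  # assumption: the PRNG implementation is fixed up front

@jax.jit
def f_raw(key_data):
    # Wrap the raw uint32 data back into a typed key inside the function.
    key = jax.random.wrap_key_data(key_data, impl=IMPL)
    new_key, subkey = jax.random.split(key)
    sample = jax.random.normal(subkey, (3,))
    # Unwrap before returning so only plain arrays appear in the outputs.
    return jax.random.key_data(new_key), sample

key = jax.random.key(0, impl=IMPL)
raw = jax.random.key_data(key)  # plain uint32 array

exported = export.export(f_raw)(raw)
blob = exported.serialize()  # inputs/outputs contain no key dtypes

restored = export.deserialize(blob)
new_raw, sample = restored.call(raw)

# Re-wrap on the caller's side to continue using typed keys.
new_key = jax.random.wrap_key_data(new_raw, impl=IMPL)
```

The trade-off is that the PRNG implementation is no longer recorded in the exported signature, so the caller has to know which implementation to use when re-wrapping.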