jjyyxx opened this issue 1 month ago
@froystig Any suggestion for how to handle this? We would have to promise backwards compatibility of the serialized artifacts.
@gnecula Besides, may I ask about the current status of user-defined custom dtypes in JAX (not via `__jax_array__`)? Is there such a public API?
> [...] may I ask about the current status of user-defined custom dtypes in JAX (not via `__jax_array__`)? Is there such a public API?
Not currently. What do you have in mind?
> @froystig Any suggestion for how to handle this?
As mentioned in the original post, if we only care to support the built-in PRNG implementations (which I think we have a case for), then how about we indeed add them as de/serializable dtypes? https://github.com/jax-ml/jax/blob/3fc4ba29ea2d5498f90c6ef8fda6edb49c1d835c/jax/_src/export/serialization.fbs#L44-L72 I assume a definition there isn't enough, and that we also need to add logic somewhere to unwrap and wrap the keys.
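A minimal sketch of what hard-coding only the built-in implementations could look like on the serialization side, in plain Python. The tags `fry`, `rbg` and `urbg` are the names JAX uses in key dtypes for its `threefry2x32`, `rbg` and `unsafe_rbg` implementations (e.g. `key<fry>`); the table `BUILTIN_KEY_DTYPES`, its numeric codes, and both helper functions are hypothetical and do not mirror the real `serialization.fbs` enum.

```python
# Hypothetical schema codes for the three built-in PRNG implementations.
# The tags match JAX's key dtype names (key<fry>, key<rbg>, key<urbg>);
# the numbers are made up for illustration only.
BUILTIN_KEY_DTYPES = {"fry": 100, "rbg": 101, "urbg": 102}
CODE_TO_TAG = {code: tag for tag, code in BUILTIN_KEY_DTYPES.items()}


def serialize_key_dtype(tag: str) -> int:
    """Map a key dtype tag to its (hypothetical) schema code."""
    if tag not in BUILTIN_KEY_DTYPES:
        # Implementations registered via jax.extend.random.define_prng_impl
        # would end up here, since only the built-in ones are hard-coded.
        raise NotImplementedError(
            f"cannot serialize key<{tag}>: not a built-in PRNG implementation"
        )
    return BUILTIN_KEY_DTYPES[tag]


def deserialize_key_dtype(code: int) -> str:
    """Inverse of serialize_key_dtype."""
    return CODE_TO_TAG[code]


# Every built-in tag survives a round trip through the schema code.
for tag in ("fry", "rbg", "urbg"):
    assert deserialize_key_dtype(serialize_key_dtype(tag)) == tag
```

Rejecting unknown tags outright keeps user-defined implementations out of the format until extensibility is actually needed, instead of silently serializing something that cannot be rebuilt on load.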
> We would have to promise backwards compatibility of the serialized artifacts.
What would that mean in this case? I would expect that serializing with typed keys never worked, considering this issue. In that case, no artifacts out there have them, and so their deserialization would not change. Does that sound correct?
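A toy illustration of why old artifacts would be unaffected, assuming the new key dtypes are only appended to the existing enum and no existing value is renumbered or removed (the usual FlatBuffers schema-evolution rule). `DTypeV1`, `DTypeV2` and `decode` are hypothetical stand-ins, not the real schema.

```python
from enum import IntEnum


class DTypeV1(IntEnum):
    # Hypothetical "before" schema; names and codes are illustrative only.
    bool_ = 0
    i32 = 1
    f32 = 2


class DTypeV2(IntEnum):
    # Hypothetical "after" schema: key dtypes appended at the end, so every
    # pre-existing member keeps its numeric code.
    bool_ = 0
    i32 = 1
    f32 = 2
    key_fry = 3
    key_rbg = 4
    key_urbg = 5


def decode(codes, enum_cls):
    """Decode serialized dtype codes using a given schema version."""
    return [enum_cls(c).name for c in codes]


# An artifact written before the change can only contain old codes ...
old_artifact = [int(DTypeV1.f32), int(DTypeV1.i32)]
# ... so it decodes identically under the extended schema.
same_result = decode(old_artifact, DTypeV1) == decode(old_artifact, DTypeV2)
# And no old member changed its code.
codes_preserved = all(DTypeV2[m.name] == m for m in DTypeV1)
```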
> [...] may I ask about the current status of user-defined custom dtypes in JAX (not via `__jax_array__`)? Is there such a public API?

> Not currently. What do you have in mind?
Not really. I was thinking that custom dtypes are similar to typed keys. If they are not publicly supported, the discussion can be restricted to typed keys only.
> I assume a definition here isn't enough and that we also need to add logic to unwrap and wrap the keys somewhere.
The workaround I posted is enough to make export with typed keys work, but I was unaware of the compatibility issues at that moment. If the implementation chooses to unwrap and wrap the keys instead, I guess it may be unnecessary to modify `serialization.fbs`?
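A minimal sketch of the unwrap/wrap idea done at the user level, assuming a JAX version where `jax.export` is public (0.4.30 or later): the exported function takes raw `uint32` key data and rebuilds the typed key inside, so only plain dtypes cross the serialization boundary. `sample` and `sample_raw` are hypothetical names, and the implementation is fixed to `threefry2x32` here; a built-in version would also have to record which implementation was used.

```python
import jax
import jax.numpy as jnp
from jax import export


def sample(key):
    # Original function: its input is a typed key (dtype key<fry>).
    return jax.random.normal(key, (3,))


def sample_raw(key_data):
    # Wrap raw uint32 key data back into a typed key inside the function,
    # so the exported signature only contains a plain uint32 array.
    key = jax.random.wrap_key_data(key_data, impl="threefry2x32")
    return sample(key)


key = jax.random.key(0, impl="threefry2x32")
raw = jax.random.key_data(key)  # unwrap: uint32 array of shape (2,)

exported = export.export(jax.jit(sample_raw))(
    jax.ShapeDtypeStruct(raw.shape, raw.dtype)
)
blob = exported.serialize()  # no key dtypes in the exported signature
restored = export.deserialize(blob)
out = restored.call(raw)

# The deserialized artifact reproduces the original computation.
expected = sample(key)
matches = bool(jnp.allclose(out, expected))
```

Doing this inside the exporter rather than in user code would keep the user-facing signature typed, at the cost of the exporter needing to know how to rebuild each key type on load.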
Description
Precisely, functions with typed keys as inputs or outputs can be exported, but cannot be serialized, since `serialization.fbs` does not contain the relevant dtypes. This can be temporarily worked around with:
If extensibility with `jax.extend.random.define_prng_impl` is not a priority, I guess it is OK to hard-code `fry`, `rbg`, and `urbg` into `serialization.fbs`, and I could raise a PR for this.

System info (python version, jaxlib version, accelerator, etc.)