Status: closed (sanderland closed this issue 3 months ago).
Problem you have encountered:
Flax exceptions cannot be pickled. This makes it impossible to trace where and why errors occur when calling Flax through Ray.
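The failure mode can be reproduced without Flax or Ray. By default, pickle rebuilds an exception as `cls(*e.args)`. An exception whose `__init__` takes several required arguments but forwards only a formatted message to `super().__init__()` therefore pickles fine but cannot be unpickled. A minimal sketch, where `ShapeError` is a hypothetical stand-in for `ScopeParamShapeError`, not Flax's actual class:

```python
import pickle


class ShapeError(Exception):
    """Hypothetical stand-in for an exception like Flax's ScopeParamShapeError."""

    def __init__(self, param_name, scope_path, value_shape, init_shape):
        # Only the formatted message is forwarded, so self.args == (message,).
        super().__init__(
            f'Shape mismatch for "{param_name}" in "{scope_path}": '
            f"{value_shape} vs {init_shape}."
        )


err = ShapeError("kernel", "/", (5, 8), (1, 8))
data = pickle.dumps(err)  # stores (ShapeError, err.args); this step succeeds

roundtrip_error = None
try:
    pickle.loads(data)  # calls ShapeError(message): three arguments are missing
except TypeError as exc:
    roundtrip_error = exc
    print(f"unpickling failed: {exc}")
```

The resulting `TypeError` has the same shape as the one Ray reports for `ScopeParamShapeError`.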
System information
Output of `pip show flax jax jaxlib`:
```
Name: flax
Version: 0.8.3
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: dm-haiku, fax
---
Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, fax, flax, optax, orbax-checkpoint, rax
---
Name: jaxlib
Version: 0.4.28
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, fax, optax, orbax-checkpoint, rax
```
What you expected to happen:
Flax exceptions should support pickling (cf. this guide).
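One common way for an exception class with required constructor arguments to support pickling is to keep those arguments and override `__reduce__`, so pickle calls the constructor with them instead of with `e.args`. The sketch below assumes that approach; `ShapeError` and its attributes are hypothetical and this is not Flax's actual fix:

```python
import pickle


class ShapeError(Exception):
    """Hypothetical exception with required constructor arguments."""

    def __init__(self, param_name, scope_path, value_shape, init_shape):
        # Keep the constructor arguments so the exception can be rebuilt later.
        self.param_name = param_name
        self.scope_path = scope_path
        self.value_shape = value_shape
        self.init_shape = init_shape
        super().__init__(
            f'Shape mismatch for "{param_name}" in "{scope_path}": '
            f"{value_shape} vs {init_shape}."
        )

    def __reduce__(self):
        # Tell pickle to call ShapeError(param_name, scope_path, value_shape,
        # init_shape) instead of the default ShapeError(*self.args).
        return (
            type(self),
            (self.param_name, self.scope_path, self.value_shape, self.init_shape),
        )


original = ShapeError("kernel", "/", (5, 8), (1, 8))
restored = pickle.loads(pickle.dumps(original))
print(type(restored).__name__, restored)
```

Overriding `__reduce__` leaves `str(e)` as the readable message. Passing every constructor argument to `super().__init__()` instead would also make the default pickling work, but it would change `e.args` and therefore the default `str(e)`.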
Logs, error messages, etc:
In production code:
```
site-packages/ray/exceptions.py", line 49, in from_ray_exception
    return pickle.loads(ray_exception.serialized_exception)
TypeError: ScopeParamShapeError.__init__() missing 3 required positional arguments: 'scope_path', 'value_shape', and 'init_shape'
```
Steps to reproduce:
```python
import pickle

import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import lecun_normal
from jax import lax, random


class NoBiasDense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        # Example from the Flax docs: the kernel shape depends on the input's last dimension.
        kernel = self.param('kernel', lecun_normal(), (x.shape[-1], self.features))  # <--- raises ScopeParamShapeError
        y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))
        return y


# Initialize with a trailing dimension of 1, then apply with a trailing dimension of 5.
variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1)))
try:
    _ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
except Exception as e:
    data = pickle.dumps(e)    # serialize the exception (avoid shadowing the builtin `str`)
    obj = pickle.loads(data)  # <--- fails with TypeError when unpickling
```
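Until the exception classes themselves are picklable, one possible workaround is to convert exceptions that fail a pickle round trip into a plain `RuntimeError` that carries the original type, message, and traceback as text before they leave the worker process. `picklable_errors` and `TwoArgError` below are hypothetical helpers sketched under that assumption; neither is part of Flax or Ray:

```python
import functools
import pickle
import traceback


def picklable_errors(fn):
    """Hypothetical helper: replace exceptions that fail a pickle round trip.

    Such exceptions are re-raised as a RuntimeError whose message holds the
    original type, message and formatted traceback as plain text.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Format the original traceback before any other exception is raised.
            details = "".join(traceback.format_exception(type(e), e, e.__traceback__))
            try:
                pickle.loads(pickle.dumps(e))  # does it survive serialization?
            except Exception:
                raise RuntimeError(f"{type(e).__name__}: {e}\n{details}") from None
            raise  # picklable: re-raise the original exception unchanged

    return wrapper


class TwoArgError(Exception):
    """Hypothetical unpicklable exception, used only for this example."""

    def __init__(self, a, b):
        super().__init__(f"{a} vs {b}")


@picklable_errors
def failing_step():
    raise TwoArgError("(5, 8)", "(1, 8)")


try:
    failing_step()
except RuntimeError as err:
    caught = err

# The replacement exception survives pickling, unlike TwoArgError.
restored = pickle.loads(pickle.dumps(caught))
print(type(restored).__name__, str(restored).splitlines()[0])
```

Applied to the function a Ray task executes, this would let Ray's `pickle.loads` call succeed, at the cost of reducing the original exception type to text.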