Hello. I'm implementing the Mixtral models with JAX and Flax, and I've run into a problem with the scan function here: I get the error "Jax transforms and Flax models cannot be mixed."
System information
OS Platform and Distribution: Ubuntu 23.04
Name: flax
Version: 0.7.5
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: EasyDeL, FJFormer
  File "/home/erfan/PycharmProjects/EasyDeL/lib/python/EasyDel/modules/mixtral/modelling_mixtral_flax.py", line 449, in expert_layer_forward
    forward_hidden_state = nn.cond(
                           ^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1353, in cond
    return lift_direct_transform(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 487, in lift_direct_transform
    return decorator_lift_transform(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 426, in wrapped_fn
    return trafo_fn(module_scopes, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1298, in _cond_wrapper
    return lift.cond(
           ^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 1085, in cond
    return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 148, in wrapper
    scope._validate_trace_level()
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/scope.py", line 545, in _validate_trace_level
    tracers.check_trace_level(self.trace_level)
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/tracers.py", line 36, in check_trace_level
    raise errors.JaxTransformError()
flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
Process finished with exit code 1
System information (jax, jaxlib)
Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, numpy, opt-einsum, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, flax, optax, orbax-checkpoint, rlax
Name: jaxlib
Version: 0.4.23
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, optax, orbax-checkpoint, rlax
Steps to reproduce:
code to get this error