google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Jax transforms and Flax models cannot be mixed #3665

Closed: erfanzar closed this issue 7 months ago

erfanzar commented 9 months ago

Hello. I'm implementing the Mixtral model with JAX and Flax, and there's a problem with the scan function here: it raises the error "Jax transforms and Flax models cannot be mixed".

System information

Name: flax
Version: 0.7.5
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: EasyDeL, FJFormer

Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, numpy, opt-einsum, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, flax, optax, orbax-checkpoint, rlax

Name: jaxlib
Version: 0.4.23
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, optax, orbax-checkpoint, rlax

Problem you have encountered:

What you expected to happen:

Logs, error messages, etc:

  File "/home/erfan/PycharmProjects/EasyDeL/lib/python/EasyDel/modules/mixtral/modelling_mixtral_flax.py", line 449, in expert_layer_forward
    forward_hidden_state = nn.cond(
                           ^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1353, in cond
    return lift_direct_transform(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 487, in lift_direct_transform
    return decorator_lift_transform(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 426, in wrapped_fn
    return trafo_fn(module_scopes, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1298, in _cond_wrapper
    return lift.cond(
           ^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 1085, in cond
    return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 148, in wrapper
    scope._validate_trace_level()
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/scope.py", line 545, in _validate_trace_level
    tracers.check_trace_level(self.trace_level)
  File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/tracers.py", line 36, in check_trace_level
    raise errors.JaxTransformError()
flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

Process finished with exit code 1

Steps to reproduce:

Code to reproduce this error (install EasyDeL first):

pip install git+https://github.com/erfanzar/EasyDeL.git

import copy
import os

# Show full tracebacks, including JAX-internal frames.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
import jax
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel
from jax import numpy as jnp
from transformers import MixtralForCausalLM
import torch
import numpy as np

def main():
    torch.manual_seed(42)
    seq_len = 128
    config = MixtralConfig(
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=1,
        num_key_value_heads=4,
        intermediate_size=512,
        num_local_experts=8,
        max_position_embeddings=seq_len
    )
    batch_size = len(jax.devices())  # one sequence per available device

    torch_model = MixtralForCausalLM(
        config=copy.deepcopy(config)
    )
    # Convert the PyTorch state dict into an EasyDeL/Flax parameter tree on CPU.
    params = {
        "params": huggingface_to_easydel(
            torch_model.state_dict(),
            embedding_layer_names=["embed_tokens"],
            device=jax.devices("cpu")[0],
        )
    }

    np_random_input_ids = np.random.randint(0, config.vocab_size, (batch_size, seq_len))
    input_ids = torch.from_numpy(np_random_input_ids).reshape(batch_size, -1).to(torch.long)
    flax_input_ids = jnp.asarray(np_random_input_ids, dtype=jnp.int32).reshape(batch_size, -1)
    torch_output = torch_model(
        input_ids=input_ids
    )
    torch_output = torch_output.logits.cpu().detach().numpy()
    config.add_jax_args()
    config.add_basic_configurations(
        use_shard_map=True
    )

    try:
        flax_model = FlaxMixtralForCausalLM(
            config=config,
            dtype=jnp.float32,
            param_dtype=jnp.float32,
            _do_init=False,
            input_shape=(batch_size, seq_len)
        )
        flax_output = flax_model(
            input_ids=flax_input_ids,
            params=params,
        )
        res = jnp.allclose(torch_output, flax_output.logits, atol=1e-5)
        print("Mixtral Huggingface Predictions :\n", torch_output,
              "\nEasyDel Predictions: \n", flax_output.logits)
        if res:  # A Little Bit of humor
            print("\033[1;36mTest Passed Unfortunately 🥳")
        else:
            print("\033[1;31mTest Failed Successfully  🤕")
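        # Mean absolute difference (a signed mean could let errors cancel out).
        error = jnp.mean(jnp.abs(torch_output - flax_output.logits))
        print("Error : ", error)
    except TypeError as e:
        # flax.errors.JaxTransformError is not a TypeError, so it is not caught
        # here and propagates with the full traceback shown above.
        print(e)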

if __name__ == "__main__":
    main()

erfanzar commented 7 months ago

I fixed that 5 or 6 weeks ago; anyway, thanks for the help!

gozde-ozcan commented 5 months ago

Hello, could you please share how you fixed this error?