jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FlaxLlamaForCausalLMModule hanging on jax-metal #24221

Open alexlatif opened 4 days ago

alexlatif commented 4 days ago

Description

To reproduce the working state, uncomment the jax.config.update line that forces the CPU backend.

import jax
import jax.numpy as jnp

# from llama import FlaxLLaMAForCausalLM  # From the ayaka14732/llama-2-jax repo
from transformers.models.llama.modeling_flax_llama import (
    FlaxLlamaForCausalLMModule,
    LlamaConfig,
)
from transformers.models.llama.tokenization_llama import LlamaTokenizer

# Uncomment to force the CPU backend (the working state)
# jax.config.update("jax_platform_name", "cpu")
print(jax.devices())

# Download the tokenizer and model config, then build the Flax module
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
conf = LlamaConfig.from_pretrained("openlm-research/open_llama_3b_v2")
print(type(conf))
model = FlaxLlamaForCausalLMModule(conf)
print(type(model))

input_prompt = "The future of AI is"
input_ids = tokenizer(input_prompt, return_tensors="jax").input_ids

rng = jax.random.PRNGKey(0)
position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape)

print(position_ids.device)

# Initialize parameters (hangs / extremely slow on the Metal backend)
params = model.init(
    rng, input_ids, attention_mask=jnp.ones_like(input_ids), position_ids=position_ids
)["params"]

# Forward pass (also very slow on the Metal backend)
model_output = model.apply(
    {"params": params},
    input_ids,
    attention_mask=jnp.ones_like(input_ids),
    position_ids=position_ids,
)

print("Model output logits:", model_output.logits)

# Greedy decode: argmax token at each position
predicted_token_ids = jnp.argmax(model_output.logits, axis=-1)

predicted_text = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

print("Predicted text:", predicted_text)

System info (python version, jaxlib version, accelerator, etc.)


jaxlib: 0.4.34
numpy:  1.26.4
python: 3.11.0 (v3.11.0:deaf509e8f, Oct 24 2022, 14:43:23) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Alessandros-Air', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:19:22 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8112', machine='arm64')
rajasekharporeddy commented 3 days ago

Hi @alexlatif

I tested the provided code with jax-metal on a MacBook Pro (M1 Pro). While there were no hanging issues, model.init and model.apply took longer than they do on CPU. Please find the attached screenshots below:

(three screenshots attached)

Thank you.
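
In case it helps with comparisons, here is a minimal timing sketch (assuming the reproduction script above has already built model, rng, input_ids, and position_ids) that blocks on the results so the asynchronously dispatched device work is actually included in the measurement:

```python
import time
import jax

# Time model.init, blocking until the parameters are materialized on device
t0 = time.perf_counter()
params = model.init(rng, input_ids,
                    attention_mask=jnp.ones_like(input_ids),
                    position_ids=position_ids)["params"]
jax.block_until_ready(params)
print(f"model.init: {time.perf_counter() - t0:.1f}s")

# Time model.apply, blocking on the logits
t0 = time.perf_counter()
out = model.apply({"params": params}, input_ids,
                  attention_mask=jnp.ones_like(input_ids),
                  position_ids=position_ids)
out.logits.block_until_ready()
print(f"model.apply: {time.perf_counter() - t0:.1f}s")
```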

alexlatif commented 3 days ago

You're right that it does eventually run. However, on a MacBook Air M2 (Sonoma 14.4.1) this took ~5 minutes. Any insight into why it's so much slower on Metal?
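
One way to narrow down where the time goes (a sketch under the same assumptions as the script above, not a confirmed diagnosis) is to jit the forward pass and compare the first call, which includes tracing and compilation for the active backend, against a second cached call:

```python
import time
import jax

@jax.jit
def forward(params, input_ids, position_ids):
    # Same forward pass as in the reproduction script, compiled once and cached
    return model.apply({"params": params}, input_ids,
                       attention_mask=jnp.ones_like(input_ids),
                       position_ids=position_ids).logits

for label in ("first call (compile + run)", "second call (cached)"):
    t0 = time.perf_counter()
    forward(params, input_ids, position_ids).block_until_ready()
    print(f"{label}: {time.perf_counter() - t0:.1f}s")
```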