erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
167 stars 19 forks source link

TPU-v3 Kaggle not working after update #161

Closed s-smits closed 3 weeks ago

s-smits commented 3 weeks ago

Describe the bug

It would be great to keep support for TPU v3s on Kaggle. With versions after 0.0.66, I get this error:


ImportError Traceback (most recent call last) File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:58, in _LazyModule._get_module(self, module_name) 57 try: ---> 58 return importlib.import_module("." + module_name, self.name) 59 except Exception as e:

File /usr/local/lib/python3.10/importlib/__init__.py:126, in import_module(name, package) 125 level += 1 --> 126 return _bootstrap._gcd_import(name[level:], package, level)

File <frozen importlib._bootstrap>:1050, in _gcd_import(name, package, level)

File <frozen importlib._bootstrap>:1027, in _find_and_load(name, import_)

File <frozen importlib._bootstrap>:992, in _find_and_load_unlocked(name, import_)

File <frozen importlib._bootstrap>:241, in _call_with_frames_removed(f, *args, **kwds)

File <frozen importlib._bootstrap>:1050, in _gcd_import(name, package, level)

File <frozen importlib._bootstrap>:1027, in _find_and_load(name, import_)

File <frozen importlib._bootstrap>:1006, in _find_and_load_unlocked(name, import_)

File <frozen importlib._bootstrap>:688, in _load_unlocked(spec)

File <frozen importlib._bootstrap_external>:883, in exec_module(self, module)

File <frozen importlib._bootstrap>:241, in _call_with_frames_removed(f, *args, **kwds)

File /usr/local/lib/python3.10/site-packages/easydel/modules/__init__.py:35 1 from . import ( 2 llama, 3 deepseek_v2, (...) 32 mistral 33 ) ---> 35 from .auto_easydel_model import ( 36 AutoEasyDeLModelForCausalLM as AutoEasyDeLModelForCausalLM, 37 AutoEasyDeLConfig as AutoEasyDeLConfig, 38 AutoShardAndGatherFunctions as AutoShardAndGatherFunctions 39 )

File /usr/local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py:12 11 import jax.numpy ---> 12 from fjformer import match_partition_rules, make_shard_and_gather_fns 14 from flax.traverse_util import unflatten_dict

File /usr/local/lib/python3.10/site-packages/fjformer/__init__.py:68 67 from . import pallas_operations as pallas_operations ---> 68 from . import optimizers as optimizers 69 from . import linen as linen

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/__init__.py:1 ----> 1 from .adamw import ( 2 get_adamw_with_cosine_scheduler as get_adamw_with_cosine_scheduler, 3 get_adamw_with_warm_up_cosine_scheduler as get_adamw_with_warm_up_cosine_scheduler, 4 get_adamw_with_warmup_linear_scheduler as get_adamw_with_warmup_linear_scheduler, 5 get_adamw_with_linear_scheduler as get_adamw_with_linear_scheduler 6 ) 7 from .lion import ( 8 get_lion_with_cosine_scheduler as get_lion_with_cosine_scheduler, 9 get_lion_with_with_warmup_linear_scheduler as get_lion_with_with_warmup_linear_scheduler, 10 get_lion_with_warm_up_cosine_scheduler as get_lion_with_warm_up_cosine_scheduler, 11 get_lion_with_linear_scheduler as get_lion_with_linear_scheduler 12 )

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/adamw.py:3 2 import chex ----> 3 import optax 6 def get_adamw_with_cosine_scheduler( 7 steps: int, 8 learning_rate: float = 5e-5, (...) 16 17 ):

File /usr/local/lib/python3.10/site-packages/optax/__init__.py:17 15 """Optax: composable gradient processing and optimization, in JAX.""" ---> 17 from optax import contrib 18 from optax import losses

File /usr/local/lib/python3.10/site-packages/optax/contrib/__init__.py:21 20 from optax.contrib.complex_valued import SplitRealAndImaginaryState ---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw 22 from optax.contrib.dadapt_adamw import DAdaptAdamWState

File /usr/local/lib/python3.10/site-packages/optax/contrib/dadapt_adamw.py:27 26 from optax._src import base ---> 27 from optax._src import utils 30 class DAdaptAdamWState(NamedTuple):

File /usr/local/lib/python3.10/site-packages/optax/_src/utils.py:22 21 import jax.numpy as jnp ---> 22 import jax.scipy.stats.norm as multivariate_normal 24 from optax._src import linear_algebra

ImportError: cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/__init__.py)

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last) Cell In[8], line 3 1 import transformers ----> 3 from easydel import ( 4 AutoEasyDeLModelForCausalLM, 5 TrainArguments, 6 EasyDeLOptimizers, 7 EasyDeLSchedulers, 8 EasyDeLGradientCheckPointers, 9 EasyDeLState, 10 EasyDeLXRapTureConfig, 11 CausalLanguageModelTrainer, 12 get_modules_by_type, 13 easystate_to_huggingface_model, 14 ) 15 from datasets import load_dataset 16 from flax.core import FrozenDict

File <frozen importlib._bootstrap>:1075, in _handle_fromlist(module, fromlist, import_, recursive)

File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:48, in _LazyModule.__getattr__(self, name) 46 value = self._get_module(name) 47 elif name in self._class_to_module.keys(): ---> 48 module = self._get_module(self._class_to_module[name]) 49 value = getattr(module, name) 50 else:

File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:60, in _LazyModule._get_module(self, module_name) 58 return importlib.import_module("." + module_name, self.name) 59 except Exception as e: ---> 60 raise RuntimeError( 61 f"Failed to import {self.name}.{module_name} because of the following error (look up to see its" 62 f" traceback):\n{e}" 63 ) from e

RuntimeError: Failed to import easydel.modules.auto_easydel_model because of the following error (look up to see its traceback): cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/__init__.py)

To Reproduce

# Reproduction cell for this issue (Kaggle TPU v3 notebook): install dependencies,
# import EasyDeL, then load ssmits/Falcon2-5.5B-Dutch via AutoEasyDeLModelForCausalLM.
# Lines starting with "!" are IPython/Jupyter shell escapes, so this runs as a
# notebook cell, not as a plain Python script.
!pip install fjformer datasets gradio wandb sentencepiece git+https://github.com/huggingface/transformers -U -q #transformers=4.41.0 
# NOTE: this pins jax 0.4.23 (the latest release at the time was 0.4.28). With
# this pin, importing easydel raised
#   ImportError: cannot import name 'stats' from 'jax.scipy'
# and the maintainer's fix in this thread was to upgrade jax to 0.4.28.
!pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # pinned to 0.4.23, not the latest 0.4.28
!pip install tensorflow --upgrade
HF_TOKEN = "HF_TOKEN"  # Replace with your actual Hugging Face token
# NOTE(review): the command below saves the literal placeholder '<HF_TOKEN>';
# the HF_TOKEN variable above is never passed to it. Substitute a real token
# before running.
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<HF_TOKEN>')"
!apt-get update && apt-get upgrade -y && apt-get install golang -y
!pip install git+https://github.com/erfanzar/EasyDeL.git

import transformers

# The failure surfaced at this import: easydel's lazy importer wrapped the jax
# ImportError in a RuntimeError. Several names imported here and below are not
# used within this snippet. Presumably later cells use them.
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    EasyDeLState,
    EasyDeLXRapTureConfig,
    CausalLanguageModelTrainer,
    get_modules_by_type,
    easystate_to_huggingface_model,
)
from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding
import jax
from transformers import AutoConfig
from huggingface_hub import HfApi
from typing import Literal

PartitionSpec = sharding.PartitionSpec
api = HfApi()

# NOTE(review): presumably the device-mesh axis sizes, with -1 meaning "use the
# remaining devices". Confirm against the EasyDeL documentation.
sharding_axis_dims = (1, 1, 1, -1)
max_length = 4096
input_shape = (1, max_length)
# input_shape = (8, 8) second try
training_run = 1

pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
# NOTE(review): this f-string has no placeholders, so the "-cp0" suffix is fixed
# and does not depend on training_run.
new_repo_id = f"ssmits/Falcon2-5.5B-Dutch-Chat-cp0"

# Model / fine-tuning settings. Within this snippet only attn_mechanism is used.
dtype = jnp.bfloat16
use_lora = False
lora_dim = 16
fully_fine_tune_parameters = False
lora_fine_tune_parameters = False
block_size = 512
attn_mechanism = "sharded_vanilla"

# The same PartitionSpecs are passed both into the model config (config_kwargs)
# and directly as keyword arguments to from_pretrained below.
attention_partitions = dict(
    query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    value_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    bias_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, None),
    attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
)

# Load the model and its params, passing the first CPU device as `device`.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    input_shape=input_shape,
    device_map="auto",
    sharding_axis_dims=sharding_axis_dims,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        **attention_partitions
    ),
    **attention_partitions
)

config = model.config

model_use_tie_word_embedding = config.tie_word_embeddings

# Wrap the loaded params under a "params" key in a FrozenDict.
model_parameters = FrozenDict({"params": params})
erfanzar commented 3 weeks ago

Hi, please update your JAX version to 0.4.28.

s-smits commented 3 weeks ago

After doing that, it installs correctly, but it freezes when importing the EasyDeL libraries:

Loaded pretrained model: ssmits/Falcon2-5.5B-Dutch
Input shape: (1, 4096)
Attention partitions: {'query_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'key_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'value_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'bias_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, None), 'attention_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp')}
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[2], line 32
     22 attention_partitions = dict(
     23     query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
     24     key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
   (...)
     27     attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
     28 )
     30 print(f"Attention partitions: {attention_partitions}")
---> 32 model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
     33     pretrained_model_name_or_path,
     34     device=jax.devices('cpu')[0],
     35     input_shape=input_shape,
     36     device_map="auto",
     37     sharding_axis_dims=sharding_axis_dims,
     38     config_kwargs=dict(
     39         use_scan_mlp=False,
     40         attn_mechanism=attn_mechanism,
     41         **attention_partitions
     42     ),
     43     **attention_partitions
     44 )
     46 print(f"Loaded model with params shape: {jax.tree_util.tree_map(lambda x: x.shape, params)}")
     48 config = model.config

NameError: name 'AutoEasyDeLModelForCausalLM' is not defined
s-smits commented 3 weeks ago

Should I make a separate issue for this?

erfanzar commented 3 weeks ago

No, it's fine. Take a look at this:

https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example

s-smits commented 3 weeks ago

Yes, thank you, it's working!