erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Kaggle training examples don't work #140

Closed · jcole75 closed this 2 months ago

jcole75 commented 2 months ago

I get an error when trying to run the Kaggle training example. I tried upgrading scipy, and that didn't help.

```
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 from EasyDel import (
      2     AutoEasyDelModelForCausalLM,
      3     TrainArguments,
      4     CausalLanguageModelTrainer,
      5     EasyDelOptimizers,
      6     EasyDelSchedulers,
      7     EasyDelGradientCheckPointers,
      8     EasyDelState,
      9     EasyDeLXRapTureConfig,
     10     get_modules_by_type,
     11     easystate_to_huggingface_model,
     12     SFTTrainer,
     13     conversations_formatting_function,
     14     AutoEasyDelConfig
     15 )
     16 from datasets import load_dataset
     17 from flax.core import FrozenDict

File /usr/local/lib/python3.10/site-packages/EasyDel/__init__.py:1
----> 1 from .serve import (
      2     EasyServe as EasyServe,
      3     EasyServeConfig as EasyServeConfig,
      4     LLMBaseReq as LLMBaseReq,
      5     GenerateAPIRequest as GenerateAPIRequest,
      6     ConversationItem as ConversationItem,
      7     ModelOutput as ModelOutput,
      8     BaseModel as BaseModel,
      9     EasyClient as EasyClient,
     10     GradioUserInference as GradioUserInference,
     11     ChatRequest as ChatRequest,
     12     InstructRequest as InstructRequest,
     13     PyTorchServer as PyTorchServer,
     14     PyTorchServerConfig as PyTorchServerConfig,
     15     JAXServer as JAXServer,
     16     JAXServerConfig as JAXServerConfig
     17 )
     19 from .modules.llama import (
     20     FlaxLlamaModel as FlaxLlamaModel,
     21     FlaxLlamaForCausalLM as FlaxLlamaForCausalLM,
    (...)
     25     VisionLlamaConfig as VisionLlamaConfig
     26 )
     27 from .modules.gpt_j import (
     28     GPTJConfig as GPTJConfig,
     29     FlaxGPTJForCausalLM as FlaxGPTJForCausalLM,
     30     FlaxGPTJModel as FlaxGPTJModel,
     31 )

File /usr/local/lib/python3.10/site-packages/EasyDel/serve/__init__.py:1
----> 1 from .jax_serve import JAXServer, JAXServerConfig
      2 from .torch_serve import PyTorchServer, PyTorchServerConfig
      3 from .utils import ChatRequest, InstructRequest

File /usr/local/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:14
     12 import uvicorn
     13 from fastapi import FastAPI
---> 14 from fjformer import make_shard_and_gather_fns, match_partition_rules, with_sharding_constraint
     16 from ..etils.etils import get_logger
     17 from ..smi import get_mem, initialise_tracking

File /usr/local/lib/python3.10/site-packages/fjformer/__init__.py:68
     60 from .utils import (
     61     JaxRNG as JaxRNG,
     62     GenerateRNG as GenerateRNG,
     63     init_rng as init_rng,
     64     next_rng as next_rng,
     65 )
     67 from . import pallas_operations as pallas_operations
---> 68 from . import optimizers as optimizers
     69 from . import linen as linen
     71 __version__ = "0.0.46"

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/__init__.py:1
----> 1 from .adamw import (
      2     get_adamw_with_cosine_scheduler as get_adamw_with_cosine_scheduler,
      3     get_adamw_with_warm_up_cosine_scheduler as get_adamw_with_warm_up_cosine_scheduler,
      4     get_adamw_with_warmup_linear_scheduler as get_adamw_with_warmup_linear_scheduler,
      5     get_adamw_with_linear_scheduler as get_adamw_with_linear_scheduler
      6 )
      7 from .lion import (
      8     get_lion_with_cosine_scheduler as get_lion_with_cosine_scheduler,
      9     get_lion_with_with_warmup_linear_scheduler as get_lion_with_with_warmup_linear_scheduler,
     10     get_lion_with_warm_up_cosine_scheduler as get_lion_with_warm_up_cosine_scheduler,
     11     get_lion_with_linear_scheduler as get_lion_with_linear_scheduler
     12 )
     13 from .adafactor import (
     14     get_adafactor_with_cosine_scheduler as get_adafactor_with_cosine_scheduler,
     15     get_adafactor_with_warm_up_cosine_scheduler as get_adafactor_with_warm_up_cosine_scheduler,
     16     get_adafactor_with_warmup_linear_scheduler as get_adafactor_with_warmup_linear_scheduler,
     17     get_adafactor_with_linear_scheduler as get_adafactor_with_linear_scheduler
     18 )

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/adamw.py:3
      1 from typing import Optional
      2 import chex
----> 3 import optax
      6 def get_adamw_with_cosine_scheduler(
      7     steps: int,
      8     learning_rate: float = 5e-5,
    (...)
     16
     17 ):
     18     """
     19
     20     :param gradient_accumulation_steps:
    (...)
     29     :return: Optimizer and Scheduler
     30     """

File /usr/local/lib/python3.10/site-packages/optax/__init__.py:17
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
    (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File /usr/local/lib/python3.10/site-packages/optax/contrib/__init__.py:21
     19 from optax.contrib.complex_valued import split_real_and_imaginary
     20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
     22 from optax.contrib.dadapt_adamw import DAdaptAdamWState
     23 from optax.contrib.mechanic import MechanicState

File /usr/local/lib/python3.10/site-packages/optax/contrib/dadapt_adamw.py:27
     25 from optax import tree_utils
     26 from optax._src import base
---> 27 from optax._src import utils
     30 class DAdaptAdamWState(NamedTuple):
     31     """State of the GradientTransformation returned by dadapt_adamw."""

File /usr/local/lib/python3.10/site-packages/optax/_src/utils.py:22
     20 import jax
     21 import jax.numpy as jnp
---> 22 import jax.scipy.stats.norm as multivariate_normal
     24 from optax._src import linear_algebra
     25 from optax._src import numerics

File /usr/local/lib/python3.10/site-packages/jax/scipy/stats/__init__.py:40
     38 from jax.scipy.stats import gennorm as gennorm
     39 from jax.scipy.stats import truncnorm as truncnorm
---> 40 from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
     41 from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
     42 from jax.scipy.stats import vonmises as vonmises

File /usr/local/lib/python3.10/site-packages/jax/_src/scipy/stats/kde.py:26
     24 from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps
     25 from jax._src.tree_util import register_pytree_node_class
---> 26 from jax.scipy import linalg, special
     29 @_wraps(osp_stats.gaussian_kde, update_doc=False)
     30 @register_pytree_node_class
     31 @dataclass(frozen=True, init=False)
     32 class gaussian_kde:
     33     neff: Any

File /usr/local/lib/python3.10/site-packages/jax/scipy/linalg.py:18
      1 # Copyright 2020 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
    (...)
     15 # Note: import <name> as <name> is required for names to be exported.
     16 # See PEP 484 & https://github.com/google/jax/issues/7570
---> 18 from jax._src.scipy.linalg import (
     19     block_diag as block_diag,
     20     cholesky as cholesky,
     21     cho_factor as cho_factor,
     22     cho_solve as cho_solve,
     23     det as det,
     24     eigh as eigh,
     25     eigh_tridiagonal as eigh_tridiagonal,
     26     expm as expm,
     27     expm_frechet as expm_frechet,
     28     hessenberg as hessenberg,
     29     inv as inv,
     30     lu as lu,
     31     lu_factor as lu_factor,
     32     lu_solve as lu_solve,
     33     polar as polar,
     34     qr as qr,
     35     rsf2csf as rsf2csf,
     36     schur as schur,
     37     sqrtm as sqrtm,
     38     solve as solve,
     39     solve_triangular as solve_triangular,
     40     svd as svd,
     41     toeplitz as toeplitz,
     42 )
     44 from jax._src.third_party.scipy.linalg import (
     45     funm as funm,
     46 )
     48 # Deprecations

File /usr/local/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:403
    399     del overwrite_b, debug, check_finite  # unused
    400     return _solve_triangular(a, b, trans, lower, unit_diagonal)
--> 403 @_wraps(scipy.linalg.tril)
    404 def tril(m: ArrayLike, k: int = 0) -> Array:
    405     return jnp.tril(m, k)
    408 @_wraps(scipy.linalg.triu)
    409 def triu(m: ArrayLike, k: int = 0) -> Array:

AttributeError: module 'scipy.linalg' has no attribute 'tril'
```

erfanzar commented 2 months ago

Hi, thanks for using EasyDeL!

Your problem should be fixed by simply restarting the session.

erfanzar commented 2 months ago

@jcole75 is your issue fixed?

jcole75 commented 2 months ago

Unfortunately, no. I restarted it and I get the same issue.

erfanzar commented 2 months ago

Are you using the latest Kaggle environment?

defdet commented 2 months ago

Hey, I've faced the same issue; use `!pip install scipy==1.10.1`. Older scipy versions work.
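
For anyone copy-pasting, the pin goes in a notebook cell before any EasyDel/JAX import (the root cause is that newer SciPy releases removed `scipy.linalg.tril`/`triu`, which this JAX build still wraps at import time):

```python
# Pin scipy before importing EasyDel: recent SciPy releases removed
# scipy.linalg.tril/triu, which jax._src.scipy.linalg still wraps here.
!pip install -q scipy==1.10.1
# Restart the kernel afterwards so the pinned version is the one imported.
```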

jcole75 commented 2 months ago

I do at least have a different error now after installing that specific scipy version. I am using the latest environment with TPU enabled.

```
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from EasyDel import (
      2     AutoEasyDelModelForCausalLM,
      3     TrainArguments,
      4     CausalLanguageModelTrainer,
      5     EasyDelOptimizers,
      6     EasyDelSchedulers,
      7     EasyDelGradientCheckPointers,
      8     EasyDelState,
      9     EasyDeLXRapTureConfig,
     10     get_modules_by_type,
     11     easystate_to_huggingface_model,
     12     SFTTrainer,
     13     conversations_formatting_function,
     14     AutoEasyDelConfig
     15 )
     16 from datasets import load_dataset
     17 from flax.core import FrozenDict

File /usr/local/lib/python3.10/site-packages/EasyDel/__init__.py:1
----> 1 from .serve import (
      2     EasyServe as EasyServe,
      3     EasyServeConfig as EasyServeConfig,
      4     LLMBaseReq as LLMBaseReq,
      5     GenerateAPIRequest as GenerateAPIRequest,
      6     ConversationItem as ConversationItem,
      7     ModelOutput as ModelOutput,
      8     BaseModel as BaseModel,
      9     EasyClient as EasyClient,
     10     GradioUserInference as GradioUserInference,
     11     ChatRequest as ChatRequest,
     12     InstructRequest as InstructRequest,
     13     PyTorchServer as PyTorchServer,
     14     PyTorchServerConfig as PyTorchServerConfig,
     15     JAXServer as JAXServer,
     16     JAXServerConfig as JAXServerConfig,
     17     create_generate_function as create_generate_function
     18 )
     20 from .modules.llama import (
     21     FlaxLlamaModel as FlaxLlamaModel,
     22     FlaxLlamaForCausalLM as FlaxLlamaForCausalLM,
    (...)
     26     VisionLlamaConfig as VisionLlamaConfig
     27 )
     28 from .modules.gpt_j import (
     29     GPTJConfig as GPTJConfig,
     30     FlaxGPTJForCausalLM as FlaxGPTJForCausalLM,
     31     FlaxGPTJModel as FlaxGPTJModel,
     32 )

File /usr/local/lib/python3.10/site-packages/EasyDel/serve/__init__.py:1
----> 1 from .jax_serve import JAXServer, JAXServerConfig
      2 from .torch_serve import PyTorchServer, PyTorchServerConfig
      3 from .utils import ChatRequest, InstructRequest

File /usr/local/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:16
     13 from fastapi import FastAPI
     14 from fjformer import make_shard_and_gather_fns, match_partition_rules, with_sharding_constraint
---> 16 from ..etils.etils import get_logger
     17 from ..smi import get_mem, initialise_tracking
     18 from jax import numpy as jnp

File /usr/local/lib/python3.10/site-packages/EasyDel/etils/__init__.py:26
     11 from .etils import (
     12     EasyDelGradientCheckPointers,
     13     EasyDelOptimizers,
    (...)
     17     AVAILABLE_GRADIENT_CHECKPOINTS
     18 )
     20 from .errors import (
     21     EasyDelTimerError,
     22     EasyDelRuntimeError,
     23     EasyDelSyntaxRuntimeError
     24 )
---> 26 from .easystate import (
     27     EasyDelState
     28 )
     30 from .auto_tx import (
     31     get_optimizer_and_scheduler
     32 )

File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:15
     13 from .auto_tx import get_optimizer_and_scheduler
     14 from ..etils import AVAILABLE_SCHEDULERS, AVAILABLE_OPTIMIZERS, EasyDelRuntimeError
---> 15 from ..modules.easydel_modelling_utils import EasyDelFlaxPretrainedModel, EasyDelPretrainedConfig
     16 from jax.sharding import Mesh, PartitionSpec
     17 from jax import numpy as jnp

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/__init__.py:115
    103 from .qwen2_moe import (
    104     Qwen2MoeConfig as Qwen2MoeConfig,
    105     FlaxQwen2MoeModel as FlaxQwen2MoeModel,
    106     FlaxQwen2MoeForCausalLM as FlaxQwen2MoeForCausalLM
    107 )
    108 from .whisper import (
    109     FlaxWhisperForConditionalGeneration as FlaxWhisperForConditionalGeneration,
    110     FlaxWhisperForAudioClassification as FlaxWhisperForAudioClassification,
    111     FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor,
    112     WhisperConfig as WhisperConfig
    113 )
--> 115 from .cohere import (
    116     FlaxCohereModel as FlaxCohereModel,
    117     CohereConfig as CohereConfig,
    118     FlaxCohereForCausalLM as FlaxCohereForCausalLM
    119 )
    121 from .auto_easydel_model import (
    122     AutoEasyDelModelForCausalLM as AutoEasyDelModelForCausalLM,
    123     AutoEasyDelConfig as AutoEasyDelConfig,
    124     AutoShardAndGatherFunctions as AutoShardAndGatherFunctions
    125 )
    127 __all__ = (
    128     "FlaxLlamaModel", "FlaxLlamaForCausalLM", "FlaxLlamaForSequenceClassification", "LlamaConfig",
    129     "VisionLlamaConfig", "FlaxVisionLlamaForCausalLM",
    (...)
    172
    173 )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/__init__.py:2
      1 from .cohere_configuration import CohereConfig
----> 2 from .modelling_cohere_flax import FlaxCohereModel, FlaxCohereForCausalLM
      4 __all__ = "CohereConfig", "FlaxCohereModel", "FlaxCohereForCausalLM"

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/modelling_cohere_flax.py:34
     22 from ..flax_modelling_utils import (
     23     with_sharding_constraint,
     24     get_gradient_checkpoint_policy,
    (...)
     30     block_wise_ffn
     31 )
     33 re_mat = flax.linen.partitioning.remat
---> 34 from transformers import CohereForCausalLM
     37 class FlaxCohereEmbedding(nn.Module):
     38     dtype: jnp.dtype = jnp.float32

ImportError: cannot import name 'CohereForCausalLM' from 'transformers' (/usr/local/lib/python3.10/site-packages/transformers/__init__.py)
```

defdet commented 2 months ago

You know, it's pretty funny: I hit those same two errors in a row today. I'm assuming this one is because of the recent updates in EasyDeL; Cohere is present in your version of EasyDeL, but not in your version of transformers. Either install the newest transformers (`pip install git+https://github.com/huggingface/transformers`), which has Cohere, or install an older EasyDeL.
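
In notebook form, either option looks roughly like this (the older EasyDeL version below is a placeholder, not a tested pin):

```python
# Option 1: transformers from source, which includes CohereForCausalLM
!pip install -q git+https://github.com/huggingface/transformers

# Option 2: an EasyDeL release from before the Cohere module was added
# (placeholder version -- check the release history for a suitable one)
# !pip install -q EasyDel==<older-version>
```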

erfanzar commented 2 months ago

I do at least have a different error now after installing that specific scipy version. I am using the latest environment with TPU enabled.

```
ImportError: cannot import name 'CohereForCausalLM' from 'transformers' (/usr/local/lib/python3.10/site-packages/transformers/__init__.py)
```

It's fixed now.

erfanzar commented 2 months ago

@jcole75 is that fixed?

jcole75 commented 2 months ago

@jcole75 is that fixed?

Getting closer. I changed to microsoft/phi-2 since Mistral is gated. Now I get:

```
KeyError                                  Traceback (most recent call last)
Cell In[21], line 1
----> 1 trainer = SFTTrainer(
      2     arguments=train_arguments,
      3     train_dataset=dataset_train,
      4     eval_dataset=None,
      5     tokenizer=tokenizer,
      6     dataset_text_field=None,
      7     # formatting_func=prompter,
      8     formatting_func=lambda x:[conversations_formatting_function(tokenizer, messages_field="conversation")(x)],
      9     packing=False,
     10     num_of_sequences=2048,
     11     checkpoint_path=checkpoint_path
     12 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:72, in SFTTrainer.__init__(self, arguments, tokenizer, train_dataset, eval_dataset, dataset_text_field, packing, formatting_func, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, dataset_kwargs, eval_packing, checkpoint_path, remove_unused_columns, _do_init_fns)
     70     dataset_kwargs = {}
     71 if train_dataset is not None:
---> 72     train_dataset = self._prepare_dataset(
     73         train_dataset,
     74         tokenizer,
     75         packing,
     76         dataset_text_field,
     77         arguments.max_sequence_length,
     78         formatting_func,
     79         num_of_sequences,
     80         chars_per_token,
     81         remove_unused_columns=remove_unused_columns,
     82         **dataset_kwargs,
     83     )
     84 if eval_dataset is not None:
     85     _multiple = isinstance(eval_dataset, dict)

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:189, in SFTTrainer._prepare_dataset(self, dataset, tokenizer, packing, dataset_text_field, max_seq_length, formatting_func, num_of_sequences, chars_per_token, remove_unused_columns, append_concat_token, add_special_tokens)
    186     raise ValueError("The dataset should not be None")
    188 if not packing:
--> 189     return self._prepare_non_packed_dataloader(
    190         tokenizer,
    191         dataset,
    192         dataset_text_field,
    193         max_seq_length,
    194         formatting_func,
    195         add_special_tokens,
    196         remove_unused_columns,
    197     )
    199 else:
    200     return self._prepare_packed_dataloader(
    201         tokenizer,
    202         dataset,
    (...)
    209         add_special_tokens,
    210     )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:261, in SFTTrainer._prepare_non_packed_dataloader(self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func, add_special_tokens, remove_unused_columns)
    252 if not remove_unused_columns and len(extra_columns) > 0:
    253     warnings.warn(
    254         "You passed remove_unused_columns=False on a non-packed dataset. This might create some issues with "
    255         "the default collator and yield to errors. If you want to inspect dataset other columns "
    (...)
    258         "unused dataset columns."
    259     )
--> 261 tokenized_dataset = dataset.map(
    262     tokenize,
    263     batched=False,
    264     remove_columns=dataset.column_names if remove_unused_columns else None,
    265     num_proc=self.dataset_num_proc,
    266     batch_size=self.dataset_batch_size,
    267 )
    269 return tokenized_dataset

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:592, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    590 self: "Dataset" = kwargs.pop("self")
    591 # apply actual function
--> 592 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    593 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    594 for dataset in datasets:
    595     # Remove task templates if a column mapping of the template is no longer valid

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:557, in transmit_format.<locals>.wrapper(*args, **kwargs)
    550 self_format = {
    551     "type": self._format_type,
    552     "format_kwargs": self._format_kwargs,
    553     "columns": self._format_columns,
    554     "output_all_columns": self._output_all_columns,
    555 }
    556 # apply actual function
--> 557 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    558 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    559 # re-apply format to the output

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3097, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3090 if transformed_dataset is None:
   3091     with logging.tqdm(
   3092         disable=not logging.is_progress_bar_enabled(),
   3093         unit=" examples",
   3094         total=pbar_total,
   3095         desc=desc or "Map",
   3096     ) as pbar:
-> 3097         for rank, done, content in Dataset._map_single(**dataset_kwargs):
   3098             if done:
   3099                 shards_done += 1

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3450, in Dataset._map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
   3448 _time = time.time()
   3449 for i, example in shard_iterable:
-> 3450     example = apply_function_on_filtered_inputs(example, i, offset=offset)
   3451     if update_data:
   3452         if i == 0:

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3353, in Dataset._map_single.<locals>.apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
   3351 if with_rank:
   3352     additional_args += (rank,)
-> 3353 processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
   3354 if isinstance(processed_inputs, LazyDict):
   3355     processed_inputs = {
   3356         k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
   3357     }

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:226, in SFTTrainer._prepare_non_packed_dataloader.<locals>.tokenize(element)
    225 def tokenize(element):
--> 226     inner = element[dataset_text_field] if not use_formatting_func else formatting_func(element)
    227     outputs = tokenizer(
    228         inner,
    229         add_special_tokens=add_special_tokens,
    (...)
    234         return_length=False,
    235     )
    237     if use_formatting_func and not self._dataset_sanity_checked:

Cell In[21], line 8, in <lambda>(x)
      1 trainer = SFTTrainer(
      2     arguments=train_arguments,
      3     train_dataset=dataset_train,
      4     eval_dataset=None,
      5     tokenizer=tokenizer,
      6     dataset_text_field=None,
      7     # formatting_func=prompter,
----> 8     formatting_func=lambda x:[conversations_formatting_function(tokenizer, messages_field="conversation")(x)],
      9     packing=False,
     10     num_of_sequences=2048,
     11     checkpoint_path=checkpoint_path
     12 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/utils.py:420, in conversations_formatting_function.<locals>.format_dataset(examples)
    419 def format_dataset(examples):
--> 420     if isinstance(examples[messages_field][0], list):
    421         output_texts = []
    422         for i in range(len(examples[messages_field])):

File /usr/local/lib/python3.10/site-packages/datasets/formatting/formatting.py:270, in LazyDict.__getitem__(self, key)
    269 def __getitem__(self, key):
--> 270     value = self.data[key]
    271     if key in self.keys_to_format:
    272         value = self.format(key)

KeyError: 'conversation'
```

defdet commented 2 months ago

Just change the messages field to 'messages', since that's the correct name of the dataset column. You might also want to take a look at my example; it fully works (with ring attention).
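
Concretely, that means renaming the column in the lambda from the earlier cell:

```python
# Before (raises KeyError): the dataset has no "conversation" column
# formatting_func=lambda x: [conversations_formatting_function(tokenizer, messages_field="conversation")(x)]

# After: the column is actually named "messages"
formatting_func = lambda x: [
    conversations_formatting_function(tokenizer, messages_field="messages")(x)
]
```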

jcole75 commented 2 months ago

Just change the messages field to 'messages', since that's the correct name of the dataset column. You might also want to take a look at my example; it fully works (with ring attention).

Changing it to 'messages' worked. When I tried your notebook, though, I got an error:

```
TypeError                                 Traceback (most recent call last)
Cell In[25], line 1
----> 1 trainer = SFTTrainer(
      2     arguments=train_arguments,
      3     train_dataset=train_ds,
      4     eval_dataset=test_ds,
      5     tokenizer=tokenizer,
      6     dataset_text_field=None,
      7     # formatting_func=prompter,
      8     formatting_func=lambda x:[conversations_formatting_function(tokenizer, messages_field="messages")(x)],
      9     packing=False,
     10 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:112, in SFTTrainer.__init__(self, arguments, tokenizer, train_dataset, eval_dataset, dataset_text_field, packing, formatting_func, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, dataset_kwargs, eval_packing, checkpoint_path, remove_unused_columns, _do_init_fns)
    105 if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
    106     warnings.warn(
    107         "You passed a tokenizer with padding_side not equal to right to the SFTTrainer. This might lead "
    108         "to some unexpected behaviour due to overflow issues when training a model in half-precision. "
    109         "You might consider adding tokenizer.padding_side = 'right' to your code."
    110     )
--> 112 super().__init__(
    113     arguments=arguments,
    114     dataset_train=train_dataset,
    115     dataset_eval=eval_dataset,
    116     finetune=True,
    117     checkpoint_path=checkpoint_path,
    118     _do_init_fns=_do_init_fns,
    119 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:153, in BaseTrainer.__init__(self, arguments, dataset_train, dataset_eval, finetune, checkpoint_path, _do_init_fns)
    147     prefix_print(
    148         "Warning",
    149         "In case of using finetune = True and Passing checkpoint_path = None"
    150         " you should pass parameters in train function"
    151     )
    152 if _do_init_fns:
--> 153     self.initialize_trainer_utils()
    154 else:
    155     prefix_print(
    156         "Warning",
    157         "you have set _do_init_fns = False so function will not me initialized you have "
    158         f"to do in manually (simply with trainer.initialize_trainer_utils() )"
    159     )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:233, in BaseTrainer.initialize_trainer_utils(self)
    230 self.timer.log(["configure dataloaders"])
    232 self.timer("configure Model, Optimizer, Scheduler and Config").start()
--> 233 model_configurations = self.configure_model()
    234 model = model_configurations.model
    235 tx = model_configurations.tx

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:384, in BaseTrainer.configure_model(self)
    377     assert self.arguments.custom_rule is not None, (
    378         "if you are using custom model to init you must"
    379         " pass custom_rule for partition rules "
    380     )
    382 self.arguments.configs_to_initialize_model_class["config"].axis_dims = self.arguments.sharding_array
--> 384 model = self.arguments.model_class(
    385     **self.arguments.configs_to_initialize_model_class,
    386     _do_init=False
    387 )
    389 config = self.arguments.configs_to_initialize_model_class["config"]
    391 else:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:586, in FlaxQwen2PreTrainedModel.__init__(self, config, input_shape, seed, dtype, _do_init, **kwargs)
    568 """
    569 The __init__ function is called when the class is instantiated.
    570 It sets up the instance of the class, and defines what happens when it's created.
    (...)
    583
    584 """
    585 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 586 super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:447, in EasyDelFlaxPretrainedModel.__init__(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init)
    436 def __init__(
    437     self,
    438     config: PretrainedConfig,
    (...)
    445     _do_init: bool = True,
    446 ):
--> 447     super().__init__(
    448         config=config,
    449         module=module,
    450         input_shape=input_shape,
    451         seed=seed,
    452         dtype=dtype,
    453         _do_init=_do_init
    454     )

File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:224, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype, _do_init)
    222 else:
    223     init_fn = partial(self.init_weights, input_shape=input_shape)
--> 224     params_shape_tree = jax.eval_shape(init_fn, self.key)
    226 logger.info(
    227     "Model weights are not initialized as `_do_init` is set to `False`. "
    228     f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
    229 )
    231 # get the shape of the parameters

[... skipping hidden 8 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:619, in FlaxQwen2PreTrainedModel.init_weights(self, rng, input_shape, params)
    609     module_init_outputs = self.module.init(
    610         rngs,
    611         input_ids,
    (...)
    616         return_dict=False,
    617     )
    618 else:
--> 619     module_init_outputs = self.module.init(
    620         rngs, input_ids, attention_mask, position_ids, return_dict=False)
    622 random_params = module_init_outputs["params"]
    624 if params is not None:

[... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:1063, in FlaxQwen2ForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
   1058 if position_ids is None:
   1059     position_ids = jnp.broadcast_to(
   1060         jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
   1061         (batch_size, seq_length)
   1062     )
-> 1063 outputs = self.model(
   1064     input_ids,
   1065     attention_mask,
   1066     position_ids,
   1067     deterministic=deterministic,
   1068     init_cache=init_cache,
   1069     output_attentions=output_attentions,
   1070     output_hidden_states=output_hidden_states,
   1071     return_dict=return_dict,
   1072     extra_embedding=extra_embedding
   1073 )
   1075 hidden_states = outputs[0]
   1077 if self.config.tie_word_embeddings:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:960, in FlaxQwen2Module.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
    956 inputs_embeds = inputs_embeds + extra_embedding if extra_embedding is not None else inputs_embeds
    957 hidden_states = self.dropout(
    958     inputs_embeds, deterministic=deterministic)
--> 960 outputs = self.layers(
    961     hidden_states=hidden_states,
    962     freq_cis=self.freq_cis,
    963     attention_mask=attention_mask,
    964     position_ids=position_ids,
    965     causal_mask=self.causal_mask,
    966     deterministic=deterministic,
    967     init_cache=init_cache,
    968     output_attentions=output_attentions,
    969     output_hidden_states=output_hidden_states,
    970     return_dict=return_dict,
    971 )
    973 hidden_states = outputs[0]
    974 hidden_states = self.norm(hidden_states)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:839, in FlaxQwen2BlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
    836 if output_hidden_states:
    837     all_hidden_states += (hidden_states,)
--> 839 layer_outputs = block(
    840     hidden_states=hidden_states,
    841     freq_cis=freq_cis,
    842     attention_mask=attention_mask,
    843     position_ids=position_ids,
    844     causal_mask=causal_mask,
    845     deterministic=deterministic,
    846     init_cache=init_cache,
    847     output_attentions=output_attentions,
    848     fcm_mask=fcm_mask,
    849 )
    850 hidden_states = layer_outputs[0]
    852 if output_attentions:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:519, in FlaxQwen2Block.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    486 def __call__(
    487     self,
    488     hidden_states: chex.Array,
    (...)
    497     fcm_mask: Optional[jnp.ndarray] = None,
    498 ):
    499     """
    500     The __call__ function is the main function of a TransformerEncoderLayer.
    501     It takes in hidden states, frequency-domain inputs, and masks as input. It then
    (...)
    517
    518     """
--> 519     attn_outputs = self.self_attn(
    520         self.input_layernorm(hidden_states),
    521         freq_cis,
    522         attention_mask,
    523         position_ids,
    524         causal_mask,
    525         segment_ids,
    526         deterministic,
    527         init_cache,
    528         output_attentions,
    529         fcm_mask,
    530     )
    531     attn_output = attn_outputs[0]
    532     hidden_states = hidden_states + attn_output

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
    564     y = fn(scope, *args)
    565     return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)

[... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
    562 args = _repack_remat_args(dyn_args, static_args)
    563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
    565 return y, repack_fn(scope)

[... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:398, in FlaxQwen2Attention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    390 attention_bias = lax.select(
    391     attention_mask > 0,
    392     jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
    393     jnp.full(attention_mask.shape, jnp.finfo(
    394         self.dtype).min).astype(self.dtype),
    395 )
    396 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 398 attentions = self.attention_performer.__call__(
    399     query_states=query_states,
    400     key_states=key_states,
    401     value_states=value_states,
    402     bias=attention_bias,
    403     attention_mask=attention_mask,
    404     causal=False,
    405     dropout_rng=dropout_rng,
    406     deterministic=deterministic,
    407     query_sequence_length=query_length,
    408     key_value_sequence_length=key_length,
    409     uses_cache=self.has_variable("cache", "cached_key") or init_cache,
    410     segment_ids=segment_ids,
    411 )
    414 attn_output = self._merge_heads(attentions.attention_outputs)
    415 if self.config.shard_attention_computation:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:286, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
    277     return self.sharded_vanilla_attention(
    278         query_states=query_states,
    279         key_states=key_states,
    (...)
    283         deterministic=deterministic,
    284     )
    285 elif self.attn_mechanism == "ring":
--> 286     return self.ring_attention(
    287         query_states=query_states,
    288         key_states=key_states,
    289         value_states=value_states,
    290         bias=bias,
    291         dropout_rng=dropout_rng,
    292         deterministic=deterministic,
    293         query_sequence_length=query_sequence_length,
    294         segment_ids=segment_ids,
    295         attention_mask=attention_mask
    296     )
    298 elif self.attn_mechanism == "splash":
    299     return self.splash_attention(
    300         query_states=query_states,
    301         key_states=key_states,
    302         value_states=value_states,
    303         segment_ids=segment_ids,
    304     )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:428, in AttentionModule.ring_attention(self, query_states, key_states, value_states, query_sequence_length, bias, attention_mask, deterministic, dropout_rng, segment_ids)
    415 query_sequence_partition = None if query_states.shape[1] == 1 else "sp"
    416 ring_attention_sharded = shard_map(
    417     partial(ring_attention_standard, axis_name=self.axis_name),
    418     mesh=self.mesh,
    (...)
    426     check_rep=False
    427 )
--> 428 attn_output = ring_attention_sharded(
    429     query_states, key_states, value_states, attention_mask
    430 )
    431 return AttentionOutput(
    432     attention_weights=None,
    433     attention_outputs=attn_output
    434 )

[... skipping hidden 9 frame]

File /usr/local/lib/python3.10/inspect.py:3186, in Signature.bind(self, *args, **kwargs)
   3181 def bind(self, /, *args, **kwargs):
   3182     """Get a BoundArguments object, that maps the passed `args`
   3183     and `kwargs` to the function's signature.  Raises `TypeError`
   3184     if the passed arguments can not be bound.
   3185     """
-> 3186     return self._bind(args, kwargs)

File /usr/local/lib/python3.10/inspect.py:3101, in Signature._bind(self, args, kwargs, partial)
   3099     msg = 'missing a required argument: {arg!r}'
   3100     msg = msg.format(arg=param.name)
-> 3101     raise TypeError(msg) from None
   3102 else:
   3103     # We have a positional argument to process
   3104     try:

TypeError: missing a required argument: 'scale'
```

erfanzar commented 2 months ago

@jcole75 hi, and thanks for finding the ring attention bug. It's fixed now, so you can try it, but since you are using a TPU v3 I don't recommend ring attention.

You can use sharded_vanilla instead, which is faster and more efficient.
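
As a rough sketch of the switch: the traceback above shows `AttentionModule` dispatching on an `attn_mechanism` string ("sharded_vanilla", "ring", "splash"); exactly where that field gets set is an assumption here:

```python
# Hedged sketch: AttentionModule dispatches on `attn_mechanism`
# ("sharded_vanilla", "ring", "splash", ...). Setting it on the model config
# before building the trainer is an assumption about where EasyDeL reads it.
config = AutoEasyDelConfig.from_pretrained("microsoft/phi-2")
config.attn_mechanism = "sharded_vanilla"  # instead of "ring"
```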

erfanzar commented 2 months ago

@jcole75 is your issue fixed?