Closed jcole75 closed 2 months ago
hi, thanks for using EasyDeL!
your problem should be fixed by simply restarting the session
@jcole75 is your issue fixed?
Unfortunately, no. I restarted it and I get the same issue.
Are you using the latest Kaggle environment?
Hey, I've faced the same issue, use !pip install scipy==1.10.1. Older scipy versions work.
I do at least have a different error now after installing the specific scipy. I am using the latest environment with TPU enabled.
`ImportError Traceback (most recent call last) Cell In[1], line 1 ----> 1 from EasyDel import ( 2 AutoEasyDelModelForCausalLM, 3 TrainArguments, 4 CausalLanguageModelTrainer, 5 EasyDelOptimizers, 6 EasyDelSchedulers, 7 EasyDelGradientCheckPointers, 8 EasyDelState, 9 EasyDeLXRapTureConfig, 10 get_modules_by_type, 11 easystate_to_huggingface_model, 12 SFTTrainer, 13 conversations_formatting_function, 14 AutoEasyDelConfig 15 ) 16 from datasets import load_dataset 17 from flax.core import FrozenDict
File /usr/local/lib/python3.10/site-packages/EasyDel/init.py:1 ----> 1 from .serve import ( 2 EasyServe as EasyServe, 3 EasyServeConfig as EasyServeConfig, 4 LLMBaseReq as LLMBaseReq, 5 GenerateAPIRequest as GenerateAPIRequest, 6 ConversationItem as ConversationItem, 7 ModelOutput as ModelOutput, 8 BaseModel as BaseModel, 9 EasyClient as EasyClient, 10 GradioUserInference as GradioUserInference, 11 ChatRequest as ChatRequest, 12 InstructRequest as InstructRequest, 13 PyTorchServer as PyTorchServer, 14 PyTorchServerConfig as PyTorchServerConfig, 15 JAXServer as JAXServer, 16 JAXServerConfig as JAXServerConfig, 17 create_generate_function as create_generate_function 18 ) 20 from .modules.llama import ( 21 FlaxLlamaModel as FlaxLlamaModel, 22 FlaxLlamaForCausalLM as FlaxLlamaForCausalLM, (...) 26 VisionLlamaConfig as VisionLlamaConfig 27 ) 28 from .modules.gpt_j import ( 29 GPTJConfig as GPTJConfig, 30 FlaxGPTJForCausalLM as FlaxGPTJForCausalLM, 31 FlaxGPTJModel as FlaxGPTJModel, 32 )
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/init.py:1 ----> 1 from .jax_serve import JAXServer, JAXServerConfig 2 from .torch_serve import PyTorchServer, PyTorchServerConfig 3 from .utils import ChatRequest, InstructRequest
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:16 13 from fastapi import FastAPI 14 from fjformer import make_shard_and_gather_fns, match_partition_rules, with_sharding_constraint ---> 16 from ..etils.etils import get_logger 17 from ..smi import get_mem, initialise_tracking 18 from jax import numpy as jnp
File /usr/local/lib/python3.10/site-packages/EasyDel/etils/init.py:26 11 from .etils import ( 12 EasyDelGradientCheckPointers, 13 EasyDelOptimizers, (...) 17 AVAILABLE_GRADIENT_CHECKPOINTS 18 ) 20 from .errors import ( 21 EasyDelTimerError, 22 EasyDelRuntimeError, 23 EasyDelSyntaxRuntimeError 24 ) ---> 26 from .easystate import ( 27 EasyDelState 28 ) 30 from .auto_tx import ( 31 get_optimizer_and_scheduler 32 )
File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:15 13 from .auto_tx import get_optimizer_and_scheduler 14 from ..etils import AVAILABLE_SCHEDULERS, AVAILABLE_OPTIMIZERS, EasyDelRuntimeError ---> 15 from ..modules.easydel_modelling_utils import EasyDelFlaxPretrainedModel, EasyDelPretrainedConfig 16 from jax.sharding import Mesh, PartitionSpec 17 from jax import numpy as jnp
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/init.py:115 103 from .qwen2_moe import ( 104 Qwen2MoeConfig as Qwen2MoeConfig, 105 FlaxQwen2MoeModel as FlaxQwen2MoeModel, 106 FlaxQwen2MoeForCausalLM as FlaxQwen2MoeForCausalLM 107 ) 108 from .whisper import ( 109 FlaxWhisperForConditionalGeneration as FlaxWhisperForConditionalGeneration, 110 FlaxWhisperForAudioClassification as FlaxWhisperForAudioClassification, 111 FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor, 112 WhisperConfig as WhisperConfig 113 ) --> 115 from .cohere import ( 116 FlaxCohereModel as FlaxCohereModel, 117 CohereConfig as CohereConfig, 118 FlaxCohereForCausalLM as FlaxCohereForCausalLM 119 ) 121 from .auto_easydel_model import ( 122 AutoEasyDelModelForCausalLM as AutoEasyDelModelForCausalLM, 123 AutoEasyDelConfig as AutoEasyDelConfig, 124 AutoShardAndGatherFunctions as AutoShardAndGatherFunctions 125 ) 127 all = ( 128 "FlaxLlamaModel", "FlaxLlamaForCausalLM", "FlaxLlamaForSequenceClassification", "LlamaConfig", 129 "VisionLlamaConfig", "FlaxVisionLlamaForCausalLM", (...) 172 173 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/init.py:2 1 from .cohere_configuration import CohereConfig ----> 2 from .modelling_cohere_flax import FlaxCohereModel, FlaxCohereForCausalLM 4 all = "CohereConfig", "FlaxCohereModel", "FlaxCohereForCausalLM"
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/modelling_cohere_flax.py:34 22 from ..flax_modelling_utils import ( 23 with_sharding_constraint, 24 get_gradient_checkpoint_policy, (...) 30 block_wise_ffn 31 ) 33 re_mat = flax.linen.partitioning.remat ---> 34 from transformers import CohereForCausalLM 37 class FlaxCohereEmbedding(nn.Module): 38 dtype: jnp.dtype = jnp.float32
ImportError: cannot import name 'CohereForCausalLM' from 'transformers' (/usr/local/lib/python3.10/site-packages/transformers/init.py)`
You know, it's pretty funny — I got those 2 errors in a row today, the same as you. I'm assuming this error is because of the recent updates in EasyDeL: Cohere is present in your version of EasyDeL, but not in your version of transformers. You can either install the newest transformers (pip install git+https://github.com/huggingface/transformers), which has Cohere, or install an older EasyDeL.
I do at least have a different error now after installing the specific scipy. I am using the latest environment with TPU enabled.
`ImportError Traceback (most recent call last)
Cell In[1], line 1
----> 1 from EasyDel import (
2 AutoEasyDelModelForCausalLM, 3 TrainArguments, 4 CausalLanguageModelTrainer, 5 EasyDelOptimizers, 6 EasyDelSchedulers, 7 EasyDelGradientCheckPointers, 8 EasyDelState, 9 EasyDeLXRapTureConfig, 10 get_modules_by_type, 11 easystate_to_huggingface_model, 12 SFTTrainer, 13 conversations_formatting_function, 14 AutoEasyDelConfig 15 ) 16 from datasets import load_dataset 17 from flax.core import FrozenDict
File /usr/local/lib/python3.10/site-packages/EasyDel/init.py:1
----> 1 from .serve import (
2 EasyServe as EasyServe, 3 EasyServeConfig as EasyServeConfig, 4 LLMBaseReq as LLMBaseReq, 5 GenerateAPIRequest as GenerateAPIRequest, 6 ConversationItem as ConversationItem, 7 ModelOutput as ModelOutput, 8 BaseModel as BaseModel, 9 EasyClient as EasyClient, 10 GradioUserInference as GradioUserInference, 11 ChatRequest as ChatRequest, 12 InstructRequest as InstructRequest, 13 PyTorchServer as PyTorchServer, 14 PyTorchServerConfig as PyTorchServerConfig, 15 JAXServer as JAXServer, 16 JAXServerConfig as JAXServerConfig, 17 create_generate_function as create_generate_function 18 ) 20 from .modules.llama import ( 21 FlaxLlamaModel as FlaxLlamaModel, 22 FlaxLlamaForCausalLM as FlaxLlamaForCausalLM,
(...)
26 VisionLlamaConfig as VisionLlamaConfig 27 ) 28 from .modules.gpt_j import ( 29 GPTJConfig as GPTJConfig, 30 FlaxGPTJForCausalLM as FlaxGPTJForCausalLM, 31 FlaxGPTJModel as FlaxGPTJModel, 32 )
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/init.py:1
----> 1 from .jax_serve import JAXServer, JAXServerConfig
2 from .torch_serve import PyTorchServer, PyTorchServerConfig 3 from .utils import ChatRequest, InstructRequest
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:16
13 from fastapi import FastAPI 14 from fjformer import make_shard_and_gather_fns, match_partition_rules, with_sharding_constraint
---> 16 from ..etils.etils import get_logger
17 from ..smi import get_mem, initialise_tracking 18 from jax import numpy as jnp
File /usr/local/lib/python3.10/site-packages/EasyDel/etils/init.py:26
11 from .etils import ( 12 EasyDelGradientCheckPointers, 13 EasyDelOptimizers,
(...)
17 AVAILABLE_GRADIENT_CHECKPOINTS 18 ) 20 from .errors import ( 21 EasyDelTimerError, 22 EasyDelRuntimeError, 23 EasyDelSyntaxRuntimeError 24 )
---> 26 from .easystate import (
27 EasyDelState 28 ) 30 from .auto_tx import ( 31 get_optimizer_and_scheduler 32 )
File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:15
13 from .auto_tx import get_optimizer_and_scheduler 14 from ..etils import AVAILABLE_SCHEDULERS, AVAILABLE_OPTIMIZERS, EasyDelRuntimeError
---> 15 from ..modules.easydel_modelling_utils import EasyDelFlaxPretrainedModel, EasyDelPretrainedConfig
16 from jax.sharding import Mesh, PartitionSpec 17 from jax import numpy as jnp
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/init.py:115
103 from .qwen2_moe import ( 104 Qwen2MoeConfig as Qwen2MoeConfig, 105 FlaxQwen2MoeModel as FlaxQwen2MoeModel, 106 FlaxQwen2MoeForCausalLM as FlaxQwen2MoeForCausalLM 107 ) 108 from .whisper import ( 109 FlaxWhisperForConditionalGeneration as FlaxWhisperForConditionalGeneration, 110 FlaxWhisperForAudioClassification as FlaxWhisperForAudioClassification, 111 FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor, 112 WhisperConfig as WhisperConfig 113 )
--> 115 from .cohere import (
116 FlaxCohereModel as FlaxCohereModel, 117 CohereConfig as CohereConfig, 118 FlaxCohereForCausalLM as FlaxCohereForCausalLM 119 ) 121 from .auto_easydel_model import ( 122 AutoEasyDelModelForCausalLM as AutoEasyDelModelForCausalLM, 123 AutoEasyDelConfig as AutoEasyDelConfig, 124 AutoShardAndGatherFunctions as AutoShardAndGatherFunctions 125 ) 127 __all__ = ( 128 "FlaxLlamaModel", "FlaxLlamaForCausalLM", "FlaxLlamaForSequenceClassification", "LlamaConfig", 129 "VisionLlamaConfig", "FlaxVisionLlamaForCausalLM",
(...)
172 173 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/init.py:2
1 from .cohere_configuration import CohereConfig
----> 2 from .modelling_cohere_flax import FlaxCohereModel, FlaxCohereForCausalLM
4 __all__ = "CohereConfig", "FlaxCohereModel", "FlaxCohereForCausalLM"
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/cohere/modelling_cohere_flax.py:34
22 from ..flax_modelling_utils import ( 23 with_sharding_constraint, 24 get_gradient_checkpoint_policy,
(...)
30 block_wise_ffn 31 ) 33 re_mat = flax.linen.partitioning.remat
---> 34 from transformers import CohereForCausalLM
37 class FlaxCohereEmbedding(nn.Module): 38 dtype: jnp.dtype = jnp.float32
ImportError: cannot import name 'CohereForCausalLM' from 'transformers' (/usr/local/lib/python3.10/site-packages/transformers/init.py)`
It's fixed now
@jcole75 is that fixed?
@jcole75 is that fixed?
Getting closer. I changed to microsoft/phi-2 since mistral is gated. Now I get:
KeyError Traceback (most recent call last) Cell In[21], line 1 ----> 1 trainer = SFTTrainer( 2 arguments=train_arguments, 3 train_dataset=dataset_train, 4 eval_dataset=None, 5 tokenizer=tokenizer, 6 dataset_text_field=None, 7 # formatting_func=prompter, 8 formatting_func=lambda x:[conversations_formatting_function(tokenizer, messages_field="conversation")(x)], 9 packing=False, 10 num_of_sequences=2048, 11 checkpoint_path=checkpoint_path 12 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:72, in SFTTrainer.init(self, arguments, tokenizer, train_dataset, eval_dataset, dataset_text_field, packing, formatting_func, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, dataset_kwargs, eval_packing, checkpoint_path, remove_unused_columns, _do_init_fns) 70 dataset_kwargs = {} 71 if train_dataset is not None: ---> 72 train_dataset = self._prepare_dataset( 73 train_dataset, 74 tokenizer, 75 packing, 76 dataset_text_field, 77 arguments.max_sequence_length, 78 formatting_func, 79 num_of_sequences, 80 chars_per_token, 81 remove_unused_columns=remove_unused_columns, 82 **dataset_kwargs, 83 ) 84 if eval_dataset is not None: 85 _multiple = isinstance(eval_dataset, dict)
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:189, in SFTTrainer._prepare_dataset(self, dataset, tokenizer, packing, dataset_text_field, max_seq_length, formatting_func, num_of_sequences, chars_per_token, remove_unused_columns, append_concat_token, add_special_tokens) 186 raise ValueError("The dataset should not be None") 188 if not packing: --> 189 return self._prepare_non_packed_dataloader( 190 tokenizer, 191 dataset, 192 dataset_text_field, 193 max_seq_length, 194 formatting_func, 195 add_special_tokens, 196 remove_unused_columns, 197 ) 199 else: 200 return self._prepare_packed_dataloader( 201 tokenizer, 202 dataset, (...) 209 add_special_tokens, 210 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:261, in SFTTrainer._prepare_non_packed_dataloader(self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func, add_special_tokens, remove_unused_columns)
252 if not remove_unused_columns and len(extra_columns) > 0:
253 warnings.warn(
254 "You passed remove_unused_columns=False
on a non-packed dataset. This might create some issues with "
255 "the default collator and yield to errors. If you want to inspect dataset other columns "
(...)
258 "unused dataset columns."
259 )
--> 261 tokenized_dataset = dataset.map(
262 tokenize,
263 batched=False,
264 remove_columns=dataset.column_names if remove_unused_columns else None,
265 num_proc=self.dataset_num_proc,
266 batch_size=self.dataset_batch_size,
267 )
269 return tokenized_dataset
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:592, in transmit_tasks.
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:557, in transmit_format.
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3097, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc) 3090 if transformed_dataset is None: 3091 with logging.tqdm( 3092 disable=not logging.is_progress_bar_enabled(), 3093 unit=" examples", 3094 total=pbar_total, 3095 desc=desc or "Map", 3096 ) as pbar: -> 3097 for rank, done, content in Dataset._map_single(**dataset_kwargs): 3098 if done: 3099 shards_done += 1
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3450, in Dataset._map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset) 3448 _time = time.time() 3449 for i, example in shard_iterable: -> 3450 example = apply_function_on_filtered_inputs(example, i, offset=offset) 3451 if update_data: 3452 if i == 0:
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:3353, in Dataset._map_single.
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:226, in SFTTrainer._prepare_non_packed_dataloader.
Cell In[21], line 8, in
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/utils.py:420, in conversations_formatting_function.
File /usr/local/lib/python3.10/site-packages/datasets/formatting/formatting.py:270, in LazyDict.getitem(self, key) 269 def getitem(self, key): --> 270 value = self.data[key] 271 if key in self.keys_to_format: 272 value = self.format(key)
KeyError: 'conversation'
Just change the messages field to 'messages', since that's the correct name of the dataset column. You also might want to take a look at my example; it fully works (with ring attention).
Just change message field to 'messages' since that's the correct name of the dataset column. Also might want to take a look at my example, it fully works (with ring attention)
Changing to messages worked. When I tried your notebook, I get an error.
`TypeError Traceback (most recent call last) Cell In[25], line 1 ----> 1 trainer = SFTTrainer( 2 arguments=train_arguments, 3 train_dataset=train_ds, 4 eval_dataset=test_ds, 5 tokenizer=tokenizer, 6 dataset_text_field=None, 7 # formatting_func=prompter, 8 formatting_func=lambda x:[conversations_formatting_function(tokenizer, messages_field="messages")(x)], 9 packing=False, 10 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/sft/stf_trainer.py:112, in SFTTrainer.init(self, arguments, tokenizer, train_dataset, eval_dataset, dataset_text_field, packing, formatting_func, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, dataset_kwargs, eval_packing, checkpoint_path, remove_unused_columns, _do_init_fns)
105 if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
106 warnings.warn(
107 "You passed a tokenizer with padding_side
not equal to right
to the SFTTrainer. This might lead "
108 "to some unexpected behaviour due to overflow issues when training a model in half-precision. "
109 "You might consider adding tokenizer.padding_side = 'right'
to your code."
110 )
--> 112 super().init(
113 arguments=arguments,
114 dataset_train=train_dataset,
115 dataset_eval=eval_dataset,
116 finetune=True,
117 checkpoint_path=checkpoint_path,
118 _do_init_fns=_do_init_fns,
119 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:153, in BaseTrainer.init(self, arguments, dataset_train, dataset_eval, finetune, checkpoint_path, _do_init_fns)
147 prefix_print(
148 "Warning",
149 "In case of using finetune = True
and Passing checkpoint_path = None
"
150 " you should pass parameters in train function"
151 )
152 if _do_init_fns:
--> 153 self.initialize_trainer_utils()
154 else:
155 prefix_print(
156 "Warning",
157 "you have set _do_init_fns = False
so function will not me initialized you have "
158 f"to do in manually (simply with trainer.initialize_trainer_utils()
)"
159 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:233, in BaseTrainer.initialize_trainer_utils(self) 230 self.timer.log(["configure dataloaders"]) 232 self.timer("configure Model, Optimizer, Scheduler and Config").start() --> 233 model_configurations = self.configure_model() 234 model = model_configurations.model 235 tx = model_configurations.tx
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:384, in BaseTrainer.configure_model(self) 377 assert self.arguments.custom_rule is not None, ( 378 "if you are using custom model to init you must" 379 " pass custom_rule for partition rules " 380 ) 382 self.arguments.configs_to_initialize_model_class["config"].axis_dims = self.arguments.sharding_array --> 384 model = self.arguments.model_class( 385 **self.arguments.configs_to_initialize_model_class, 386 _do_init=False 387 ) 389 config = self.arguments.configs_to_initialize_model_class["config"] 391 else:
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:586, in FlaxQwen2PreTrainedModel.init(self, config, input_shape, seed, dtype, _do_init, kwargs) 568 """ 569 The init function is called when the class is instantiated. 570 It sets up the instance of the class, and defines what happens when it's created. (...) 583 584 """ 585 module = self.module_class(config=config, dtype=dtype, kwargs) --> 586 super().init(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:447, in EasyDelFlaxPretrainedModel.init(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init) 436 def init( 437 self, 438 config: PretrainedConfig, (...) 445 _do_init: bool = True, 446 ): --> 447 super().init( 448 config=config, 449 module=module, 450 input_shape=input_shape, 451 seed=seed, 452 dtype=dtype, 453 _do_init=_do_init 454 )
File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:224, in FlaxPreTrainedModel.init(self, config, module, input_shape, seed, dtype, _do_init)
222 else:
223 init_fn = partial(self.init_weights, input_shape=input_shape)
--> 224 params_shape_tree = jax.eval_shape(init_fn, self.key)
226 logger.info(
227 "Model weights are not initialized as _do_init
is set to False
. "
228 f"Make sure to call {self.__class__.__name__}.init_weights
manually to initialize the weights."
229 )
231 # get the shape of the parameters
[... skipping hidden 8 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:619, in FlaxQwen2PreTrainedModel.init_weights(self, rng, input_shape, params) 609 module_init_outputs = self.module.init( 610 rngs, 611 input_ids, (...) 616 return_dict=False, 617 ) 618 else: --> 619 module_init_outputs = self.module.init( 620 rngs, input_ids, attention_mask, position_ids, return_dict=False) 622 random_params = module_init_outputs["params"] 624 if params is not None:
[... skipping hidden 9 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:1063, in FlaxQwen2ForCausalLMModule.call(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding) 1058 if position_ids is None: 1059 position_ids = jnp.broadcast_to( 1060 jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), 1061 (batch_size, seq_length) 1062 ) -> 1063 outputs = self.model( 1064 input_ids, 1065 attention_mask, 1066 position_ids, 1067 deterministic=deterministic, 1068 init_cache=init_cache, 1069 output_attentions=output_attentions, 1070 output_hidden_states=output_hidden_states, 1071 return_dict=return_dict, 1072 extra_embedding=extra_embedding 1073 ) 1075 hidden_states = outputs[0] 1077 if self.config.tie_word_embeddings:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:960, in FlaxQwen2Module.call(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding) 956 inputs_embeds = inputs_embeds + extra_embedding if extra_embedding is not None else inputs_embeds 957 hidden_states = self.dropout( 958 inputs_embeds, deterministic=deterministic) --> 960 outputs = self.layers( 961 hidden_states=hidden_states, 962 freq_cis=self.freq_cis, 963 attention_mask=attention_mask, 964 position_ids=position_ids, 965 causal_mask=self.causal_mask, 966 deterministic=deterministic, 967 init_cache=init_cache, 968 output_attentions=output_attentions, 969 output_hidden_states=output_hidden_states, 970 return_dict=return_dict, 971 ) 973 hidden_states = outputs[0] 974 hidden_states = self.norm(hidden_states)
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:839, in FlaxQwen2BlockCollection.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict) 836 if output_hidden_states: 837 all_hidden_states += (hidden_states,) --> 839 layer_outputs = block( 840 hidden_states=hidden_states, 841 freq_cis=freq_cis, 842 attention_mask=attention_mask, 843 position_ids=position_ids, 844 causal_mask=causal_mask, 845 deterministic=deterministic, 846 init_cache=init_cache, 847 output_attentions=output_attentions, 848 fcm_mask=fcm_mask, 849 ) 850 hidden_states = layer_outputs[0] 852 if output_attentions:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:519, in FlaxQwen2Block.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask) 486 def call( 487 self, 488 hidden_states: chex.Array, (...) 497 fcm_mask: Optional[jnp.ndarray] = None, 498 ): 499 """ 500 The call function is the main function of a TransformerEncoderLayer. 501 It takes in hidden states, frequency-domain inputs, and masks as input. It then (...) 517 518 """ --> 519 attn_outputs = self.self_attn( 520 self.input_layernorm(hidden_states), 521 freq_cis, 522 attention_mask, 523 position_ids, 524 causal_mask, 525 segment_ids, 526 deterministic, 527 init_cache, 528 output_attentions, 529 fcm_mask, 530 ) 531 attn_output = attn_outputs[0] 532 hidden_states = hidden_states + attn_output
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.
[... skipping hidden 7 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.
[... skipping hidden 3 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/qwen2/modelling_qwen_flax.py:398, in FlaxQwen2Attention.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask) 390 attention_bias = lax.select( 391 attention_mask > 0, 392 jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 393 jnp.full(attention_mask.shape, jnp.finfo( 394 self.dtype).min).astype(self.dtype), 395 ) 396 query_length, key_length = query_states.shape[1], key_states.shape[1] --> 398 attentions = self.attention_performer.call( 399 query_states=query_states, 400 key_states=key_states, 401 value_states=value_states, 402 bias=attention_bias, 403 attention_mask=attention_mask, 404 causal=False, 405 dropout_rng=dropout_rng, 406 deterministic=deterministic, 407 query_sequence_length=query_length, 408 key_value_sequence_length=key_length, 409 uses_cache=self.has_variable("cache", "cached_key") or init_cache, 410 segment_ids=segment_ids, 411 ) 414 attn_output = self._merge_heads(attentions.attention_outputs) 415 if self.config.shard_attention_computation:
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:286, in AttentionModule.call(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache) 277 return self.sharded_vanilla_attention( 278 query_states=query_states, 279 key_states=key_states, (...) 283 deterministic=deterministic, 284 ) 285 elif self.attn_mechanism == "ring": --> 286 return self.ring_attention( 287 query_states=query_states, 288 key_states=key_states, 289 value_states=value_states, 290 bias=bias, 291 dropout_rng=dropout_rng, 292 deterministic=deterministic, 293 query_sequence_length=query_sequence_length, 294 segment_ids=segment_ids, 295 attention_mask=attention_mask 296 ) 298 elif self.attn_mechanism == "splash": 299 return self.splash_attention( 300 query_states=query_states, 301 key_states=key_states, 302 value_states=value_states, 303 segment_ids=segment_ids, 304 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:428, in AttentionModule.ring_attention(self, query_states, key_states, value_states, query_sequence_length, bias, attention_mask, deterministic, dropout_rng, segment_ids) 415 query_sequence_partition = None if query_states.shape[1] == 1 else "sp" 416 ring_attention_sharded = shard_map( 417 partial(ring_attention_standard, axis_name=self.axis_name), 418 mesh=self.mesh, (...) 426 check_rep=False 427 ) --> 428 attn_output = ring_attention_sharded( 429 query_states, key_states, value_states, attention_mask 430 ) 431 return AttentionOutput( 432 attention_weights=None, 433 attention_outputs=attn_output 434 )
[... skipping hidden 9 frame]
File /usr/local/lib/python3.10/inspect.py:3186, in Signature.bind(self, *args, *kwargs)
3181 def bind(self, /, args, **kwargs):
3182 """Get a BoundArguments object, that maps the passed args
3183 and kwargs
to the function's signature. Raises TypeError
3184 if the passed arguments can not be bound.
3185 """
-> 3186 return self._bind(args, kwargs)
File /usr/local/lib/python3.10/inspect.py:3101, in Signature._bind(self, args, kwargs, partial) 3099 msg = 'missing a required argument: {arg!r}' 3100 msg = msg.format(arg=param.name) -> 3101 raise TypeError(msg) from None 3102 else: 3103 # We have a positional argument to process 3104 try:
TypeError: missing a required argument: 'scale'`
@jcole75 hi, and thanks for finding the ring attention issue. It's fixed now and you can try it, but since you are using TPU-v3 I don't recommend using ring attention.
You can use sharded_vanilla instead, which is faster and more efficient.
@jcole75 is your issue fixed?
I get an error when trying to run it. I tried upgrading scipy and that didn't help.
`AttributeError Traceback (most recent call last) Cell In[4], line 1 ----> 1 from EasyDel import ( 2 AutoEasyDelModelForCausalLM, 3 TrainArguments, 4 CausalLanguageModelTrainer, 5 EasyDelOptimizers, 6 EasyDelSchedulers, 7 EasyDelGradientCheckPointers, 8 EasyDelState, 9 EasyDeLXRapTureConfig, 10 get_modules_by_type, 11 easystate_to_huggingface_model, 12 SFTTrainer, 13 conversations_formatting_function, 14 AutoEasyDelConfig 15 ) 16 from datasets import load_dataset 17 from flax.core import FrozenDict
File /usr/local/lib/python3.10/site-packages/EasyDel/init.py:1 ----> 1 from .serve import ( 2 EasyServe as EasyServe, 3 EasyServeConfig as EasyServeConfig, 4 LLMBaseReq as LLMBaseReq, 5 GenerateAPIRequest as GenerateAPIRequest, 6 ConversationItem as ConversationItem, 7 ModelOutput as ModelOutput, 8 BaseModel as BaseModel, 9 EasyClient as EasyClient, 10 GradioUserInference as GradioUserInference, 11 ChatRequest as ChatRequest, 12 InstructRequest as InstructRequest, 13 PyTorchServer as PyTorchServer, 14 PyTorchServerConfig as PyTorchServerConfig, 15 JAXServer as JAXServer, 16 JAXServerConfig as JAXServerConfig 17 ) 19 from .modules.llama import ( 20 FlaxLlamaModel as FlaxLlamaModel, 21 FlaxLlamaForCausalLM as FlaxLlamaForCausalLM, (...) 25 VisionLlamaConfig as VisionLlamaConfig 26 ) 27 from .modules.gpt_j import ( 28 GPTJConfig as GPTJConfig, 29 FlaxGPTJForCausalLM as FlaxGPTJForCausalLM, 30 FlaxGPTJModel as FlaxGPTJModel, 31 )
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/init.py:1 ----> 1 from .jax_serve import JAXServer, JAXServerConfig 2 from .torch_serve import PyTorchServer, PyTorchServerConfig 3 from .utils import ChatRequest, InstructRequest
File /usr/local/lib/python3.10/site-packages/EasyDel/serve/jax_serve.py:14 12 import uvicorn 13 from fastapi import FastAPI ---> 14 from fjformer import make_shard_and_gather_fns, match_partition_rules, with_sharding_constraint 16 from ..etils.etils import get_logger 17 from ..smi import get_mem, initialise_tracking
File /usr/local/lib/python3.10/site-packages/fjformer/init.py:68 60 from .utils import ( 61 JaxRNG as JaxRNG, 62 GenerateRNG as GenerateRNG, 63 init_rng as init_rng, 64 next_rng as next_rng, 65 ) 67 from . import pallas_operations as pallas_operations ---> 68 from . import optimizers as optimizers 69 from . import linen as linen 71 version = "0.0.46"
File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/init.py:1 ----> 1 from .adamw import ( 2 get_adamw_with_cosine_scheduler as get_adamw_with_cosine_scheduler, 3 get_adamw_with_warm_up_cosine_scheduler as get_adamw_with_warm_up_cosine_scheduler, 4 get_adamw_with_warmup_linear_scheduler as get_adamw_with_warmup_linear_scheduler, 5 get_adamw_with_linear_scheduler as get_adamw_with_linear_scheduler 6 ) 7 from .lion import ( 8 get_lion_with_cosine_scheduler as get_lion_with_cosine_scheduler, 9 get_lion_with_with_warmup_linear_scheduler as get_lion_with_with_warmup_linear_scheduler, 10 get_lion_with_warm_up_cosine_scheduler as get_lion_with_warm_up_cosine_scheduler, 11 get_lion_with_linear_scheduler as get_lion_with_linear_scheduler 12 ) 13 from .adafactor import ( 14 get_adafactor_with_cosine_scheduler as get_adafactor_with_cosine_scheduler, 15 get_adafactor_with_warm_up_cosine_scheduler as get_adafactor_with_warm_up_cosine_scheduler, 16 get_adafactor_with_warmup_linear_scheduler as get_adafactor_with_warmup_linear_scheduler, 17 get_adafactor_with_linear_scheduler as get_adafactor_with_linear_scheduler 18 )
File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/adamw.py:3 1 from typing import Optional 2 import chex ----> 3 import optax 6 def get_adamw_with_cosine_scheduler( 7 steps: int, 8 learning_rate: float = 5e-5, (...) 16 17 ): 18 """ 19 20 :param gradient_accumulation_steps: (...) 29 :return: Optimizer and Scheduler 30 """
File /usr/local/lib/python3.10/site-packages/optax/init.py:17 1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); (...) 13 # limitations under the License. 14 # ============================================================================== 15 """Optax: composable gradient processing and optimization, in JAX.""" ---> 17 from optax import contrib 18 from optax import losses 19 from optax import monte_carlo
File /usr/local/lib/python3.10/site-packages/optax/contrib/init.py:21 19 from optax.contrib.complex_valued import split_real_and_imaginary 20 from optax.contrib.complex_valued import SplitRealAndImaginaryState ---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw 22 from optax.contrib.dadapt_adamw import DAdaptAdamWState 23 from optax.contrib.mechanic import MechanicState
File /usr/local/lib/python3.10/site-packages/optax/contrib/dadapt_adamw.py:27 25 from optax import tree_utils 26 from optax._src import base ---> 27 from optax._src import utils 30 class DAdaptAdamWState(NamedTuple): 31 """State of the
GradientTransformation
returned bydadapt_adamw
."""File /usr/local/lib/python3.10/site-packages/optax/_src/utils.py:22 20 import jax 21 import jax.numpy as jnp ---> 22 import jax.scipy.stats.norm as multivariate_normal 24 from optax._src import linear_algebra 25 from optax._src import numerics
File /usr/local/lib/python3.10/site-packages/jax/scipy/stats/init.py:40 38 from jax.scipy.stats import gennorm as gennorm 39 from jax.scipy.stats import truncnorm as truncnorm ---> 40 from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde 41 from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata 42 from jax.scipy.stats import vonmises as vonmises
File /usr/local/lib/python3.10/site-packages/jax/_src/scipy/stats/kde.py:26 24 from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps 25 from jax._src.tree_util import register_pytree_node_class ---> 26 from jax.scipy import linalg, special 29 @_wraps(osp_stats.gaussian_kde, update_doc=False) 30 @register_pytree_node_class 31 @dataclass(frozen=True, init=False) 32 class gaussian_kde: 33 neff: Any
File /usr/local/lib/python3.10/site-packages/jax/scipy/linalg.py:18 1 # Copyright 2020 The JAX Authors. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); (...) 15 # Note: import as is required for names to be exported.
16 # See PEP 484 & https://github.com/google/jax/issues/7570
---> 18 from jax._src.scipy.linalg import (
19 block_diag as block_diag,
20 cholesky as cholesky,
21 cho_factor as cho_factor,
22 cho_solve as cho_solve,
23 det as det,
24 eigh as eigh,
25 eigh_tridiagonal as eigh_tridiagonal,
26 expm as expm,
27 expm_frechet as expm_frechet,
28 hessenberg as hessenberg,
29 inv as inv,
30 lu as lu,
31 lu_factor as lu_factor,
32 lu_solve as lu_solve,
33 polar as polar,
34 qr as qr,
35 rsf2csf as rsf2csf,
36 schur as schur,
37 sqrtm as sqrtm,
38 solve as solve,
39 solve_triangular as solve_triangular,
40 svd as svd,
41 toeplitz as toeplitz,
42 )
44 from jax._src.third_party.scipy.linalg import (
45 funm as funm,
46 )
48 # Deprecations
File /usr/local/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:403 399 del overwrite_b, debug, check_finite # unused 400 return _solve_triangular(a, b, trans, lower, unit_diagonal) --> 403 @_wraps(scipy.linalg.tril) 404 def tril(m: ArrayLike, k: int = 0) -> Array: 405 return jnp.tril(m, k) 408 @_wraps(scipy.linalg.triu) 409 def triu(m: ArrayLike, k: int = 0) -> Array:
AttributeError: module 'scipy.linalg' has no attribute 'tril'`