erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

How to continue training from a previous saved easydel checkpoint? #126

Closed. IvoryTower800 closed this issue 6 months ago.

IvoryTower800 commented 6 months ago

Describe the bug: I saved a checkpoint "writer_gemma_2b_it-S51.easy". Could you please guide me on how to continue training from this checkpoint? In addition, should I add save_optimizer_state=True in TrainArguments? Thank you!

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
)
output = trainer.train()

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[32], line 7
      1 trainer = CausalLanguageModelTrainer(
      2     train_arguments,
      3     dataset_train,
      4     checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
      5 )
----> 7 output = trainer.train()

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:451, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    449     initialise_tracking(dir_prefix=dir_prefix)
    450 start_time = time.time()
--> 451 sharded_state, shard_fns, gather_fns = self.initialize_state(
    452     model_parameters=model_parameters,
    453     state=state
    454 )
    456 count_model_parameters(sharded_state.params)
    457 with self.mesh:

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:304, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    298 def initialize_state(
    299         self,
    300         model_parameters: Optional[flax.core.FrozenDict] = None,
    301         state: Optional[EasyDelState] = None,
    302 ) -> Tuple[EasyDelState, Mapping[str, Callable], Mapping[str, Callable]]:
    303     if model_parameters is None and state is None and self.rapture is None:
--> 304         raise RuntimeError(
    305             "You are passing `model_parameters=None` and `state=None` and also you are not using LoRA if you are "
    306             "Using LoRA make sure to pass parameters and Rapture Config correctly otherwise pass the "
    307             "model_parameters or state."
    308         )
    309     if model_parameters is None and state is None:
    310         model_parameters = self.lora_parameters

RuntimeError: You are passing `model_parameters=None` and `state=None` and also you are not using LoRA if you are Using LoRA make sure to pass parameters and Rapture Config correctly otherwise pass the model_parameters or state.

Then I added model_parameters, and it says:

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
)

output = trainer.train(flax.core.FrozenDict({"params": params}))

---------------------------------------------------------------------------
EasyDelTimerError                         Traceback (most recent call last)
Cell In[43], line 7
      1 trainer = CausalLanguageModelTrainer(
      2     train_arguments,
      3     dataset_train,
      4     checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
      5 )
----> 7 output = trainer.train(flax.core.FrozenDict({"params": params}))

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:451, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    449     initialise_tracking(dir_prefix=dir_prefix)
    450 start_time = time.time()
--> 451 sharded_state, shard_fns, gather_fns = self.initialize_state(
    452     model_parameters=model_parameters,
    453     state=state
    454 )
    456 count_model_parameters(sharded_state.params)
    457 with self.mesh:

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:372, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    370     sharded_state = self.create_sharded_state_from_params_function(model_parameters)
    371 elif model_parameters is not None and self.checkpoint_path is not None:
--> 372     raise EasyDelTimerError(
    373         "You can't pass `model_parameters` and `checkpoint_path` at same time"
    374     )
    375 else:
    376     raise EasyDelTimerError(
    377         "You should pass `model_parameters` or `checkpoint_path` to trainer in order to load model"
    378     )

EasyDelTimerError: You can't pass `model_parameters` and `checkpoint_path` at same time
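Reading only the two checks visible in these tracebacks, the reported version rejected both attempts: the first check (line 303) does not consult `checkpoint_path` at all, and the second (line 371) rejects `model_parameters` combined with `checkpoint_path`. The hypothetical function below mirrors just those two checks in plain Python. It is not EasyDeL code, it uses built-in exceptions instead of `EasyDelTimerError`, and it ignores the other branches of `initialize_state`, which are not shown here.

```python
from typing import Any, Optional


def reported_input_checks(
    model_parameters: Optional[Any],
    state: Optional[Any],
    checkpoint_path: Optional[str],
    using_lora: bool = False,
) -> str:
    """Hypothetical mirror of two checks in CausalLanguageModelTrainer.initialize_state,
    as shown in the tracebacks above (not the real EasyDeL implementation)."""
    # Check at line 303: checkpoint_path is not consulted here.
    if model_parameters is None and state is None and not using_lora:
        raise RuntimeError("pass model_parameters or state (or use LoRA)")
    # Check at line 371: parameters and a checkpoint path are mutually exclusive.
    if model_parameters is not None and checkpoint_path is not None:
        raise ValueError("can't pass model_parameters and checkpoint_path at same time")
    return "passed both checks"


CKPT = "writer_gemma_2b_it-S51.easy"

# The two attempts from this issue: checkpoint only, then checkpoint plus params.
attempts = {
    "checkpoint only": dict(model_parameters=None, state=None, checkpoint_path=CKPT),
    "checkpoint + params": dict(
        model_parameters={"params": {}}, state=None, checkpoint_path=CKPT
    ),
}

outcomes = {}
for label, kwargs in attempts.items():
    try:
        outcomes[label] = reported_input_checks(**kwargs)
    except (RuntimeError, ValueError) as err:
        outcomes[label] = type(err).__name__
```

In this model no input combination resumes from `checkpoint_path` alone, which is consistent with a fix being needed on the EasyDeL side before the checkpoint could load.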
erfanzar commented 6 months ago

Hi, and thanks for using EasyDeL! Can you try that again? It's fixed.

IvoryTower800 commented 6 months ago

@erfanzar Thank you. I tried the latest version. It can load the checkpoint now, but it still cannot continue training.

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
)

output = trainer.train()
print(f"Hey ! , here's where your model saved {output.checkpoint_path}") 

Time Took to Complete Task configure dataloaders (microseconds) : 0.2803802490234375
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 1699.0447044372559
Time Took to Complete Task configure functions and sharding them (microseconds) : 1788.7015342712402
Action : Loading Model From /kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy
Loading Checkpoints From File: 177it [00:41,  4.27it/s, shard_functions_mismatch=2]
Couldn't load past optimizer State initializing new one with default Optimizer and Scheduler
Model Contain 2.506172416 Billion Parameters
  0%|          | 0/10000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 157
    145 dataset_train = dataset_train.map(
    146     tokenization_process,
    147     num_proc=os.cpu_count(),
    148     remove_columns=dataset_train.column_names
    149 )
    151 trainer = CausalLanguageModelTrainer(
    152     train_arguments,
    153     dataset_train,
    154     checkpoint_path='/kaggle/working/EasyDel-Checkpoints/writer_gemma_2b_it/writer_gemma_2b_it-S51.easy'
    155 )
--> 157 output = trainer.train()
    158 print(f"Hey ! , here's where your model saved {output.checkpoint_path}") 

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:536, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    534 for ssb in self.arguments.ids_to_pop_from_dataset:
    535     _ = batch.pop(ssb, None)
--> 536 sharded_state, loss, accuracy = self.sharded_train_step_function(
    537     sharded_state,
    538     batch
    539 )
    540 loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss
    541 accuracy_sum = accuracy.tolist() if accuracy_sum is None else accuracy_sum + accuracy

    [... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:952, in _prefix_error.<locals>.<lambda>(name)
    945   full_tree_meta_str = str(full_tree_meta)
    946   metadata_diff = textwrap.indent(
    947       "\n".join(
    948           difflib.ndiff(prefix_tree_meta_str.splitlines(),
    949                         full_tree_meta_str.splitlines())),
    950       prefix="    ")
    951   yield lambda name: ValueError(
--> 952     "pytree structure error: different pytree metadata at key path\n"
    953     f"    {{name}}{keystr(key_path)}\n"
    954     f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
    955     f"    {type(prefix_tree)}\n"
    956     f"with metadata\n"
    957     f"    {prefix_tree_meta_str}\n"
    958     f"but at the same key path the full pytree has a subtree of the same "
    959     f"type but with metadata\n"
    960     f"    {full_tree_meta_str}\n"
    961     f"so the diff in the metadata at these pytree nodes is\n"
    962     f"{metadata_diff}".format(name=name))
    963   return  # don't look for more errors in this subtree
    965 # If the root types and numbers of children agree, there must be an error
    966 # in a subtree, so recurse:

KeyError: '\n\t  "bos_token_id"'
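The `KeyError` appears to hide the real exception. In the `jax/_src/tree_util.py` frame shown, the message is assembled from f-strings that interpolate the pytree metadata and is then passed through `.format(name=name)`. If that metadata is, presumably, a JSON-like config dump containing `{`, `str.format` reads `{\n\t  "bos_token_id"` as a replacement field and raises `KeyError` instead of the intended "pytree structure error" `ValueError`. A minimal stdlib reproduction of that pattern follows; the metadata string is made up to match the key in the error, not taken from the real state, and `build_message` is a hypothetical helper.

```python
def build_message(name: str, meta_str: str) -> str:
    # Same pattern as the JAX frame above: f-strings interpolate meta_str,
    # then str.format(name=...) runs over the whole message.
    return (
        "pytree structure error: different pytree metadata at key path\n"
        f"    {{name}}\n"
        f"with metadata\n"
        f"    {meta_str}\n"
    ).format(name=name)


# Hypothetical metadata string: a config dump that contains braces.
META = '{\n\t  "bos_token_id": 2,\n\t  "eos_token_id": 1\n}'

try:
    build_message("in_shardings", META)
except KeyError as err:
    # The brace-delimited text was parsed as a field name and looked up as a kwarg.
    masked_key = err.args[0]


def build_message_escaped(name: str, meta_str: str) -> str:
    # Doubling the braces first lets the intended message through.
    return build_message(name, meta_str.replace("{", "{{").replace("}", "}}"))


message = build_message_escaped("in_shardings", META)
```

So the useful information is the underlying "different pytree metadata" mismatch, not a missing `bos_token_id` key.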

In addition, it says "Couldn't load past optimizer State initializing new one with default Optimizer and Scheduler". Does that mean I need to set save_optimizer_state=True in TrainArguments if I want to load the past optimizer state?

IvoryTower800 commented 6 months ago

Hi, I may have found another issue. I tried to convert the saved .easy file to a Hugging Face model through "easystate_to_huggingface_model".

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(
            output.checkpoint_path
        ),
        base_huggingface_module=PhiForCausalLM,
        config=model.config
    )

model.push_to_hub(repo_id="writer-phi-2" )
tokenizer.push_to_hub(repo_id="writer-phi-2")
Downloading shards: 100%|██████████| 2/2 [01:52<00:00, 56.44s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 51
     44 huggingface_repo_id_or_path = "ivt1993/writer-phi-2"
     46 tokenizer = AutoTokenizer.from_pretrained(
     47     'microsoft/phi-2',
     48     trust_remote_code=True
     49 )
---> 51 model, params = AutoEasyDelModelForCausalLM.from_pretrained(
     52     huggingface_repo_id_or_path,
     53 )
     54 model.config.rope_scaling =   {
     55     "factor": 4.0,
     56     "type": "linear"
     57   }
     59 max_length = 8192

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/auto_easydel_model.py:344, in AutoEasyDelModelForCausalLM.from_pretrained(cls, pretrained_model_name_or_path, device, dtype, param_dtype, precision, sharding_axis_dims, sharding_axis_names, query_partition_spec, key_partition_spec, value_partition_spec, bias_partition_spec, attention_partition_spec, use_shard_map, input_shape, shard_fns, backend, config_kwargs, **kwargs)
    341 cfg, module, trf = get_modules_by_type(model_type)
    343 logger.debug(f"Downloading model weights from {pretrained_model_name_or_path}")
--> 344 model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
    345 cfg = cfg.from_pretrained(pretrained_model_name_or_path)
    346 state_dict = model.state_dict()

File /usr/local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    559 elif type(config) in cls._model_mapping.keys():
    560     model_class = _get_model_class(config, cls._model_mapping)
--> 561     return model_class.from_pretrained(
    562         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    563     )
    564 raise ValueError(
    565     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    566     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    567 )

File /usr/local/lib/python3.10/site-packages/transformers/modeling_utils.py:3502, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3493     if dtype_orig is not None:
   3494         torch.set_default_dtype(dtype_orig)
   3495     (
   3496         model,
   3497         missing_keys,
   3498         unexpected_keys,
   3499         mismatched_keys,
   3500         offload_index,
   3501         error_msgs,
-> 3502     ) = cls._load_pretrained_model(
   3503         model,
   3504         state_dict,
   3505         loaded_state_dict_keys,  # XXX: rename?
   3506         resolved_archive_file,
   3507         pretrained_model_name_or_path,
   3508         ignore_mismatched_sizes=ignore_mismatched_sizes,
   3509         sharded_metadata=sharded_metadata,
   3510         _fast_init=_fast_init,
   3511         low_cpu_mem_usage=low_cpu_mem_usage,
   3512         device_map=device_map,
   3513         offload_folder=offload_folder,
   3514         offload_state_dict=offload_state_dict,
   3515         dtype=torch_dtype,
   3516         hf_quantizer=hf_quantizer,
   3517         keep_in_fp32_modules=keep_in_fp32_modules,
   3518     )
   3520 # make sure token embedding weights are still tied if needed
   3521 model.tie_weights()

File /usr/local/lib/python3.10/site-packages/transformers/modeling_utils.py:3977, in PreTrainedModel._load_pretrained_model(***failed resolving arguments***)
   3973     if "size mismatch" in error_msg:
   3974         error_msg += (
   3975             "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
   3976         )
-> 3977     raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
   3979 if len(unexpected_keys) > 0:
   3980     archs = [] if model.config.architectures is None else model.config.architectures

RuntimeError: Error(s) in loading state_dict for PhiForCausalLM:
    size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([51200, 2560]) from checkpoint, the shape in current model is torch.Size([51200, 2048]).
    size mismatch for model.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([2560, 2560]) from checkpoint, the shape in current model is torch.Size([2048, 2048]).
    size mismatch for model.layers.0.self_attn.q_proj.bias: copying a param with shape torch.Size([2560]) from checkpoint, the shape in current model is torch.Size([2048]).
    size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([2560, 2560]) from checkpoint, the shape in current model is torch.Size([2048, 2048]).
    size mismatch for model.layers.0.self_attn.k_proj.bias: copying a param with shape torch.Size([2560]) from checkpoint, the shape in current model is torch.Size([2048]).
    size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([2560, 2560]) from checkpoint, the shape in current model is torch.Size([2048, 2048]).
    size mismatch for model.layers.0.self_attn.v_proj.bias: copying a param with shape torch.Size([2560]) from checkpoint, the shape in current model is torch.Size([2048]).

However, when I load it again, the error above is raised. It's a Phi-2 model.

erfanzar commented 6 months ago

Hello, that's actually a Hugging Face bug. I faced the same issue, and it's not related to EasyDeL, but I can tell you how to fix it. You have two options:

  1. Give the EasyDeL model creator the new model config instead of the default one (e.g. AutoEasyDelConfig).
  2. Change the Phi model config on Hugging Face (model tab -> config.json) and paste in this config:
{
  "_name_or_path": "microsoft/phi-2",
  "architectures": [
    "PhiForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_phi.PhiConfig",
    "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
  },
  "attention_dropout": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "hidden_act": "gelu_new",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 10240,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "phi",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "partial_rotary_factor": 0.4,
  "qk_layernorm": false,
  "resid_pdrop": 0.1,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.37.0",
  "use_cache": true,
  "vocab_size": 51200
}
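The shapes in the size-mismatch error are consistent with this config: every listed checkpoint tensor matches `hidden_size: 2560`, while the model being built expected 2048. A small stdlib check, assuming the usual PyTorch `nn.Linear` weight layout `(out_features, in_features)` and that Phi's `q_proj` maps `hidden_size` to `num_attention_heads * head_dim` (equal to `hidden_size` here). `expected_shapes` and `find_mismatches` are hypothetical helpers, not part of EasyDeL or transformers.

```python
import json

# A subset of the config.json above (only the fields this check needs).
PHI2_CONFIG = json.loads(
    '{"hidden_size": 2560, "num_attention_heads": 32, "vocab_size": 51200}'
)

# Checkpoint-side shapes copied from the "size mismatch" error above.
CHECKPOINT_SHAPES = {
    "model.embed_tokens.weight": (51200, 2560),
    "model.layers.0.self_attn.q_proj.weight": (2560, 2560),
    "model.layers.0.self_attn.q_proj.bias": (2560,),
}


def expected_shapes(cfg: dict) -> dict:
    """Shapes a Phi-style config implies for the tensors listed above."""
    hidden = cfg["hidden_size"]
    head_dim = hidden // cfg["num_attention_heads"]
    q_out = cfg["num_attention_heads"] * head_dim  # equals hidden for this config
    return {
        "model.embed_tokens.weight": (cfg["vocab_size"], hidden),
        "model.layers.0.self_attn.q_proj.weight": (q_out, hidden),
        "model.layers.0.self_attn.q_proj.bias": (q_out,),
    }


def find_mismatches(cfg: dict, ckpt: dict) -> list:
    """Names of checkpoint tensors whose shape disagrees with the config."""
    exp = expected_shapes(cfg)
    return [name for name, shape in ckpt.items() if exp[name] != shape]


# A config that disagrees only in hidden_size, like the model built at load time.
wrong_cfg = dict(PHI2_CONFIG, hidden_size=2048)
```

2048 is also the default `hidden_size` of `transformers.PhiConfig`, so one plausible (unconfirmed) reading is that the config used at load time was not the Phi-2 one.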
IvoryTower800 commented 6 months ago

@erfanzar Thank you so much! After modifying config.json on Hugging Face, the model loads correctly.

What about the KeyError: '\n\t  "bos_token_id"' issue? I tried both Gemma and Phi-2 models, and both raise this error.

erfanzar commented 6 months ago

You found a bug in EasyDeL. I had never tried that feature: there's a memory-pointer mismatch error from jax.pjit, which happens because the module is re-initialized twice, so the model that gets passed in is not the compiled one and this error is raised. I'm fixing it right now. Thank you for reporting it <3.
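The "different pytree metadata" error in the traceback above means two pytrees had the same kind of leaves but unequal static metadata, and the `bos_token_id` key suggests the model config was part of that metadata. A toy JAX illustration follows; `ModelState` is a hypothetical class, not EasyDeL's `EasyDelState`. Auxiliary data registered for a custom pytree node is compared with `==`, so two objects holding the same arrays but different configs have different tree structures, and tree operations that pair them up fail.

```python
import jax
import jax.numpy as jnp


class ModelState:
    """Toy container: `params` become pytree leaves, `config` is static metadata."""

    def __init__(self, params, config):
        self.params = params
        self.config = config


def _flatten(state):
    # (children, aux_data); aux_data is compared with == when treedefs are compared.
    return (state.params,), tuple(sorted(state.config.items()))


def _unflatten(aux_data, children):
    return ModelState(children[0], dict(aux_data))


jax.tree_util.register_pytree_node(ModelState, _flatten, _unflatten)

params = {"w": jnp.ones((2, 2))}
compiled_for = ModelState(params, {"bos_token_id": 2})
passed_in = ModelState(params, {"bos_token_id": 1})  # same arrays, different config

same_structure = (
    jax.tree_util.tree_structure(compiled_for)
    == jax.tree_util.tree_structure(passed_in)
)

try:
    # Pairing the two trees leaf-by-leaf requires matching structures.
    jax.tree_util.tree_map(lambda x, y: x + y, compiled_for, passed_in)
    paired_ok = True
except ValueError:
    paired_ok = False
```

The sketch only shows why such a mismatch is reported, not how EasyDeL fixed it.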

erfanzar commented 6 months ago

Hello, I think you can try that again. It's fixed, and the trainer has passed all of the tests.

IvoryTower800 commented 6 months ago

Hi, I tried that again, and yes, it works now. I can continue training from the .easy file.

erfanzar commented 6 months ago

Nice. If there's any other issue, please let me know.