erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Example shown on https://pypi.org/project/EasyDeL/ for fine-tuning TinyLlama raises an exception on Kaggle #105

Closed · jchauhan closed this 4 months ago

jchauhan commented 4 months ago

Describe the bug

Action : Sharding Passed Parameters
Model Contain 1.100048384 Billion Parameters
  0%|          | 0/1500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 101
     93 # you can do the same for evaluation process dataset
     95 trainer = CausalLanguageModelTrainer(
     96     train_arguments,
     97     dataset_train,
     98     checkpoint_path=None
     99 )
--> 101 output = trainer.train(flax.core.FrozenDict({"params": params}))
    102 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:509, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    507 try:
    508     for epoch in range(self.arguments.num_train_epochs):
--> 509         for batch in self.dataloader_train:
    510             current_step += 1
    511             if (
    512                     self.arguments.step_start_point is not None
    513                     and
    514                     self.arguments.step_start_point > current_step
    515             ):

File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    672 def _next_data(self):
    673     index = self._next_index()  # may raise StopIteration
--> 674     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675     if self._pin_memory:
    676         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /usr/local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     52 else:
     53     data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:179, in CausalLanguageModelTrainer.create_collate_function.<locals>.collate_fn(batch)
    175     else:
    176         corrected_sequence = [
    177             jnp.array(f[key])[..., :max_sequence_length] for f in batch
    178         ]
--> 179     results[key] = jnp.stack(corrected_sequence).reshape(
    180         -1,
    181         corrected_sequence[0].shape[-1]
    182     )
    183 return results

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1796, in stack(arrays, axis, out, dtype)
   1794 for a in arrays:
   1795   if shape(a) != shape0:
-> 1796     raise ValueError("All input arrays must have the same shape.")
   1797   new_arrays.append(expand_dims(a, axis))
   1798 return concatenate(new_arrays, axis=axis, dtype=dtype)

ValueError: All input arrays must have the same shape
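A likely reading of the last frames: `collate_fn` slices every field with `[..., :max_sequence_length]` but never pads, so whenever the examples in a batch have different lengths and at least one is shorter than `max_sequence_length`, the per-example arrays end up with different shapes and `jnp.stack` rejects them. The sketch below reproduces this with NumPy (`np.stack` enforces the same same-shape rule as `jnp.stack`) and shows a pad-after-truncate variant next to it. The toy batch, both function names, and the pad id `0` are hypothetical. This is an illustration of the failure mode, not EasyDeL's actual fix.

```python
import numpy as np

# Hypothetical toy batch: two tokenized examples of different lengths.
batch = [
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]},
    {"input_ids": [6, 7, 8], "attention_mask": [1, 1, 1]},
]
max_sequence_length = 4


def collate_truncate_only(batch):
    """Mirrors the collate_fn in the traceback: truncate each field, never pad."""
    results = {}
    for key in batch[0]:
        seqs = [np.array(f[key])[..., :max_sequence_length] for f in batch]
        # For this batch the shapes are (4,) and (3,), so np.stack raises ValueError.
        results[key] = np.stack(seqs).reshape(-1, seqs[0].shape[-1])
    return results


def collate_pad_and_truncate(batch, pad_token_id=0):
    """Hypothetical variant: truncate, then right-pad to exactly max_sequence_length."""
    results = {}
    for key in batch[0]:
        # Padded positions get attention 0; other fields get the (assumed) pad id.
        pad_value = 0 if key == "attention_mask" else pad_token_id
        seqs = []
        for f in batch:
            seq = np.array(f[key])[..., :max_sequence_length]
            n_pad = max_sequence_length - seq.shape[-1]
            seqs.append(np.pad(seq, (0, n_pad), constant_values=pad_value))
        results[key] = np.stack(seqs)
    return results


try:
    collate_truncate_only(batch)
except ValueError as err:
    print("truncate only ->", err)

fixed = collate_pad_and_truncate(batch)
print(fixed["input_ids"].shape)  # (2, 4)
```

Whatever pad id is used for `input_ids`, the matching `attention_mask` entries should be 0 so the padded positions are masked out.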

To Reproduce

Install dependencies:

# !pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install EasyDeL==0.0.50
!pip install sentencepiece
!pip install jaxlib==0.4.19
!pip install jax[tpu]==0.4.19 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 
!apt-get update && apt-get upgrade -y
!apt-get install golang -y 

Run the example on Kaggle using TPUs.
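If the dataset passed to `CausalLanguageModelTrainer` is not already padded, one workaround that needs no library change is to give every example the same length before building the trainer. The following is a minimal plain-Python sketch, assuming each row holds an unpadded `input_ids` list. `pad_example`, `MAX_LENGTH`, and `PAD_TOKEN_ID` are hypothetical names, and `MAX_LENGTH` is assumed to match the `max_sequence_length` that the collate function in the traceback truncates to.

```python
from typing import Dict, List

MAX_LENGTH = 4    # hypothetical; should match the trainer's max_sequence_length
PAD_TOKEN_ID = 0  # assumption; use your tokenizer's pad (or EOS) token id


def pad_example(example: Dict[str, List[int]],
                max_length: int = MAX_LENGTH,
                pad_token_id: int = PAD_TOKEN_ID) -> Dict[str, List[int]]:
    """Truncate input_ids to max_length, right-pad, and rebuild attention_mask.

    Assumes the incoming row is unpadded, so every kept token gets mask 1.
    """
    ids = list(example["input_ids"])[:max_length]
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [pad_token_id] * n_pad,
        "attention_mask": [1] * len(ids) + [0] * n_pad,
    }


rows = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}]
padded = [pad_example(r) for r in rows]
print({len(r["input_ids"]) for r in padded})  # {4}
```

With a Hugging Face `datasets.Dataset`, the same function can be applied per row via `dataset_train.map(pad_example)`. Alternatively, tokenizing with `padding="max_length", truncation=True, max_length=...` produces fixed-length rows up front.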

jchauhan commented 4 months ago

@erfanzar Any progress on it?

erfanzar commented 4 months ago

Can you try that again?

erfanzar commented 4 months ago

This issue is being closed as a duplicate of #104.