Action : Sharding Passed Parameters
Model Contain 1.100048384 Billion Parameters
0%| | 0/1500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 101
93 # you can do the same for evaluation process dataset
95 trainer = CausalLanguageModelTrainer(
96 train_arguments,
97 dataset_train,
98 checkpoint_path=None
99 )
--> 101 output = trainer.train(flax.core.FrozenDict({"params": params}))
102 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:509, in CausalLanguageModelTrainer.train(self, model_parameters, state)
507 try:
508 for epoch in range(self.arguments.num_train_epochs):
--> 509 for batch in self.dataloader_train:
510 current_step += 1
511 if (
512 self.arguments.step_start_point is not None
513 and
514 self.arguments.step_start_point > current_step
515 ):
File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
 628             # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File /usr/local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
672 def _next_data(self):
673 index = self._next_index() # may raise StopIteration
--> 674 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
675 if self._pin_memory:
676 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File /usr/local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
52 else:
53 data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:179, in CausalLanguageModelTrainer.create_collate_function.<locals>.collate_fn(batch)
175 else:
176 corrected_sequence = [
177 jnp.array(f[key])[..., :max_sequence_length] for f in batch
178 ]
--> 179 results[key] = jnp.stack(corrected_sequence).reshape(
180 -1,
181 corrected_sequence[0].shape[-1]
182 )
183 return results
File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1796, in stack(arrays, axis, out, dtype)
1794 for a in arrays:
1795 if shape(a) != shape0:
-> 1796 raise ValueError("All input arrays must have the same shape.")
1797 new_arrays.append(expand_dims(a, axis))
1798 return concatenate(new_arrays, axis=axis, dtype=dtype)
ValueError: All input arrays must have the same shape.
Describe the bug
Training crashes in the collate function: `jnp.stack` raises `ValueError: All input arrays must have the same shape.` because the batched sequences have differing lengths after truncation to `max_sequence_length`.
To Reproduce:
1. Install the dependencies.
2. Run the example on Kaggle using TPUs.