huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

How to set_epoch with interleave_datasets? #7051

Closed jonathanasdf closed 1 month ago

jonathanasdf commented 1 month ago

Let's say I have dataset A which has 100k examples, and dataset B which has 100m examples.

I want to train on an interleaved dataset of A+B with stopping_strategy='all_exhausted' so that dataset B doesn't repeat any examples. But every time A is exhausted I want it to be reshuffled (e.g. by calling set_epoch).

Of course I want to interleave them as IterableDatasets / in streaming mode so B doesn't have to get tokenized completely at the start.

How could I achieve this? I was thinking of something like wrapping dataset A in a new IterableDataset with from_generator() and manually calling set_epoch there before interleaving it, but I'm not sure how to keep the number of shards in that dataset...

Something like

import itertools

from datasets import IterableDataset, interleave_datasets, load_dataset

dataset_a = load_dataset(...)
dataset_b = load_dataset(...)

def epoch_shuffled_dataset(ds):
  # How to make this maintain the number of shards in ds??
  for epoch in itertools.count():
    ds.set_epoch(epoch)
    yield from iter(ds)

shuffled_dataset_a = IterableDataset.from_generator(epoch_shuffled_dataset, gen_kwargs={'ds': dataset_a})
interleaved = interleave_datasets([shuffled_dataset_a, dataset_b], probabilities=probs, stopping_strategy='all_exhausted')

lhoestq commented 1 month ago

This is not possible right now afaik :/

Maybe we could have something like this ? wdyt ?


ds = interleave_datasets(
    [shuffled_dataset_a, dataset_b],
    probabilities=probabilities,
    stopping_strategy='all_exhausted',
    reshuffle_each_iteration=True,
)

jonathanasdf commented 1 month ago

That would be helpful for this case!

If there were some way for from_generator to iterate over just a single shard of some dataset, that would probably be even better. Maybe something like

def from_dataset_generator(dataset, generator_fn, gen_kwargs):
  # calls generator_fn(dataset=dataset_shard, **gen_kwargs)

Another transform I was trying to implement is an input bucketing transform. Essentially you need to iterate through a dataset and reorder the examples in it, which is not really possible with a map() call. But using from_generator() causes the final dataset to be a single shard, which loses the speed gains from multiple dataloader workers (see the sketch below).
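
For concreteness, a rough sketch of the bucketing idea (bucket_examples and the 'input_ids' length key are just made-up names for illustration):

from datasets import IterableDataset

def bucket_examples(ds, bucket_size=1024):
  # Buffer examples, then re-emit each buffer sorted by length so that
  # similarly-sized examples end up next to each other.
  buffer = []
  for example in ds:
    buffer.append(example)
    if len(buffer) >= bucket_size:
      buffer.sort(key=lambda ex: len(ex['input_ids']))
      yield from buffer
      buffer = []
  buffer.sort(key=lambda ex: len(ex['input_ids']))
  yield from buffer

bucketed = IterableDataset.from_generator(bucket_examples, gen_kwargs={'ds': dataset_a})
# this works, but `bucketed` ends up with n_shards == 1, so multiple dataloader workers don't help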

lhoestq commented 1 month ago

I see. There are some internal functions to get a single shard already, but the public .shard() method hasn't been implemented yet for IterableDataset :/

(see the use of ex_iterable.shard_data_sources in IterableDataset._prepare_ex_iterable_for_iteration for example)
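
If you want to hack it in the meantime, it could look roughly like this (untested sketch; it relies on private attributes like _ex_iterable that may change):

import datasets

def get_shard(dataset: datasets.IterableDataset, index: int) -> datasets.IterableDataset:
    # distribute the underlying data sources as if there were `n_shards` workers,
    # and keep only the sources assigned to worker `index`
    ex_iterable = dataset._ex_iterable.shard_data_sources(worker_id=index, num_workers=dataset.n_shards)
    return datasets.IterableDataset(
        ex_iterable=ex_iterable,
        info=dataset._info.copy(),
        split=dataset._split,
    )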

jonathanasdf commented 1 month ago

Is that something planned on the roadmap for the near future, or would you suggest hacking through with the internal APIs for now?

jonathanasdf commented 1 month ago

Ok this turned out to be not too difficult. Are there any obvious issues with my implementation?

import copy
import itertools
import typing

import numpy as np

import datasets
from datasets import iterable_dataset


class ShuffleEveryEpochIterable(iterable_dataset._BaseExamplesIterable):
  """ExamplesIterable that reshuffles the dataset every epoch."""

  def __init__(
    self,
    ex_iterable: iterable_dataset._BaseExamplesIterable,
    generator: np.random.Generator,
  ):
    """Constructor."""
    super().__init__()
    self.ex_iterable = ex_iterable
    self.generator = generator

  def _init_state_dict(self) -> dict:
    self._state_dict = {
      'ex_iterable': self.ex_iterable._init_state_dict(),
      'epoch': 0,
    }
    return self._state_dict

  @typing.override
  def __iter__(self):
    epoch = self._state_dict['epoch'] if self._state_dict else 0
    for i in itertools.count(epoch):
      # Create effective seed using i (subtract in order to avoid overflow in long_scalars)
      effective_seed = copy.deepcopy(self.generator).integers(0, 1 << 63) - i
      effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
      generator = np.random.default_rng(effective_seed)
      self.ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
      if self._state_dict:
        self._state_dict['epoch'] = i
        self._state_dict['ex_iterable'] = self.ex_iterable._init_state_dict()
      it = iter(self.ex_iterable)
      yield from it

  @typing.override
  def shuffle_data_sources(self, generator):
    ex_iterable = self.ex_iterable.shuffle_data_sources(generator)
    return ShuffleEveryEpochIterable(ex_iterable, generator=generator)

  @typing.override
  def shard_data_sources(self, worker_id: int, num_workers: int):
    ex_iterable = self.ex_iterable.shard_data_sources(worker_id, num_workers)
    return ShuffleEveryEpochIterable(ex_iterable, generator=self.generator)

  @typing.override
  @property
  def n_shards(self) -> int:
    return self.ex_iterable.n_shards

# Usage: wrap an existing IterableDataset `dataset` so it reshuffles every epoch,
# given a `seed` and a shuffle `buffer_size`.
generator = np.random.default_rng(seed)
shuffling = iterable_dataset.ShufflingConfig(generator=generator, _original_seed=seed)
ex_iterable = iterable_dataset.BufferShuffledExamplesIterable(
  dataset._ex_iterable, buffer_size=buffer_size, generator=generator
)
ex_iterable = ShuffleEveryEpochIterable(ex_iterable, generator=generator)
dataset = datasets.IterableDataset(
  ex_iterable=ex_iterable,
  info=dataset._info.copy(),
  split=dataset._split,
  formatting=dataset._formatting,
  shuffling=shuffling,
  distributed=copy.deepcopy(dataset._distributed),
  token_per_repo_id=dataset._token_per_repo_id,
)

lhoestq commented 1 month ago

Nice! This iterable is infinite though, no? How would interleave_datasets know when to stop?

Maybe the re-shuffling could be implemented directly in RandomlyCyclingMultiSourcesExamplesIterable (which is the iterable used by interleave_datasets)?

jonathanasdf commented 1 month ago

Infinite is fine for my use cases, fortunately.
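
And if I ever do need it bounded, I think I can just cap the interleaved stream with take() (num_examples_per_epoch being whatever budget I pick):

interleaved = interleaved.take(num_examples_per_epoch)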