descriptinc / audiotools

Object-oriented handling of audio data, with GPU-powered augmentations, and more.
https://descriptinc.github.io/audiotools/
MIT License

Adding datasets #21

Closed pseeth closed 2 years ago

pseeth commented 2 years ago

This PR adds the base classes needed to implement the RoomSimulator. The BaseDataset class is taken from the wav2wav code and enables an important feature: attributes shared across processes. Here's the situation. Say you want the duration to double after 100 epochs. Ideally you'd write something like this:

dataset = ...
for epoch in range(num_epochs):
  if epoch == 100:
    dataset.duration = 2 * dataset.duration

This works fine if you're working with a single process. However, an issue occurs if you use a dataloader:

import torch

dataset = ...
dataloader = torch.utils.data.DataLoader(
  dataset,
  batch_size=16,
  num_workers=2
)
for epoch in range(num_epochs):
  if epoch == 100:
    dataloader.dataset.duration = 2 * dataloader.dataset.duration

That last line doesn't work. It only changes the duration attribute of the original dataset object, which lives in the main process. Each data worker gets its own copy of the original dataset, so changing the attribute on the original doesn't propagate to the copies the workers already hold.
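
A rough way to see this without starting any workers is to copy a dataset object and then change the original. This is only a sketch: `ToyDataset` is a made-up class, and `copy.deepcopy` stands in for whatever mechanism hands each worker its own copy.

```python
import copy


class ToyDataset:
    """Hypothetical stand-in for a dataset; not part of audiotools."""

    def __init__(self, duration):
        self.duration = duration


def double_original(dataset):
    # deepcopy plays the role of the copy a worker process would hold.
    worker_copy = copy.deepcopy(dataset)
    dataset.duration = 2 * dataset.duration
    return dataset.duration, worker_copy.duration


if __name__ == "__main__":
    print(double_original(ToyDataset(duration=1.0)))  # (2.0, 1.0)
```

The original sees the doubled duration, while the copy keeps the old value.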

To fix this, this PR introduces the BaseDataset class, which uses the Manager object from the multiprocessing module to share selected attributes across processes. The attributes of the base dataset that are shared are marked here:

https://github.com/descriptinc/lyrebird-audiotools/blob/2f98541721afb226600ea2f15df4a5e3ce5e4faa/audiotools/data/datasets.py#L16-L18

This makes the snippet above work. Specifically, doing dataloader.dataset.duration = 2 * dataloader.dataset.duration will now double the duration on all data workers.
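
For intuition, here is a minimal sketch of how manager-backed attributes could work. It is not the code in this PR: `SharedAttrDataset`, `_shared`, `_store`, `worker`, and `demo` are hypothetical names, and the real BaseDataset may be structured differently. The sketch assumes the shared store is passed in from outside, so the manager itself stays alive in the parent process.

```python
import multiprocessing as mp


class SharedAttrDataset:
    """Hypothetical sketch; not the BaseDataset from this PR."""

    # Attribute names that live in the shared store instead of on the instance.
    _shared = ("duration",)

    def __init__(self, store, duration):
        # store is any mapping; pass manager.dict() to share across processes.
        object.__setattr__(self, "_store", store)
        self.duration = duration

    def __setattr__(self, name, value):
        if name in type(self)._shared:
            self._store[name] = value  # a proxy forwards this to the manager
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name in type(self)._shared:
            return self._store[name]
        raise AttributeError(name)


def worker(dataset, ready, results):
    ready.wait()  # wait until the parent has changed the value
    results.put(dataset.duration)


def demo():
    with mp.Manager() as manager:
        dataset = SharedAttrDataset(manager.dict(), duration=1.0)
        ready, results = mp.Event(), mp.Queue()
        proc = mp.Process(target=worker, args=(dataset, ready, results))
        proc.start()
        dataset.duration = 2 * dataset.duration  # changed after the worker started
        ready.set()
        value = results.get()
        proc.join()
        return value


if __name__ == "__main__":
    print(demo())
```

Because reads and writes of `duration` go through the managed dict rather than the instance, a worker that started before the change should still report the doubled value.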

There's a catch - transforms

The Manager object only detects and communicates changes to all processes in a shallow way. That is to say, if you change the attribute directly (like changing the duration), it knows about that. But let's say the value of the attribute is not just a float or int, but rather an object, like the transform. This won't work:

dataloader.dataset.transform.some_transform_property = new_value

This change doesn't get propagated! The Manager object doesn't know that the attribute's value has changed. Instead, you should rebuild the transform and re-assign it, like this:

new_transform = Transform(new_value)
dataloader.dataset.transform = new_transform

And this propagates.
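
The same shallow behaviour can be reproduced with a bare managed dict. The sketch below is illustrative only: `Transform`, its `prob` property, and `mutate_then_reassign` are hypothetical and not taken from audiotools. It runs the same two steps against a plain dict and against a `Manager().dict()`.

```python
import multiprocessing as mp


class Transform:
    """Hypothetical transform with a single tunable property."""

    def __init__(self, prob):
        self.prob = prob


def mutate_then_reassign(store):
    store["transform"] = Transform(prob=0.1)

    # In-place change on whatever the store hands back.
    store["transform"].prob = 0.9
    after_mutation = store["transform"].prob

    # Rebuild and re-assign, so the store itself sees a new value.
    store["transform"] = Transform(prob=0.9)
    after_reassign = store["transform"].prob
    return after_mutation, after_reassign


if __name__ == "__main__":
    print(mutate_then_reassign({}))  # plain dict: (0.9, 0.9)
    with mp.Manager() as manager:
        # A managed dict returns a copy on each read, so the in-place change is lost.
        print(mutate_then_reassign(manager.dict()))  # (0.1, 0.9)
```

With the managed dict, the in-place change lands on a temporary copy and is discarded; only the re-assignment reaches the manager.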

Other changes

Added the CSVDataset, which is basically the RoomSimulator, but a bit more flexible. It takes a list of CSV files describing audio.

Moved the collate function to be a static method of BaseDataset (and all datasets that derive from it), so that when building a DataLoader, you can just do collate_fn=dataset.collate.
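
As a rough illustration of the pattern, the sketch below puts a collate function on the dataset class as a static method. `DictDataset` and its dict-of-lists collate are hypothetical; the collate in this PR may do something quite different.

```python
class DictDataset:
    """Hypothetical dataset; the real collate in this PR may behave differently."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

    @staticmethod
    def collate(batch):
        # Turn a list of dicts into a dict of lists, key by key.
        return {key: [item[key] for item in batch] for key in batch[0]}


dataset = DictDataset([{"idx": 0, "duration": 1.0}, {"idx": 1, "duration": 2.0}])
batch = dataset.collate([dataset[0], dataset[1]])
print(batch)  # {'idx': [0, 1], 'duration': [1.0, 2.0]}
```

Since a static method needs no instance state, the same callable can be handed to a loader, for example `torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=dataset.collate)`.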

Note: This PR is made to ps/transforms right now, but once that's merged, I'll switch its target to master.