lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

Allow passing custom DataLoaders to SoundStreamTrainer #145

Closed · hmartiro closed this 1 year ago

hmartiro commented 1 year ago

This allows customizing data loading beyond the default SoundDataset, which reads from a folder of audio files on disk. For example, it lets someone connect a compatible WebDataset loader instead.

This also adds type annotations to SoundStreamTrainer's constructor; I can revert that if you don't like it.

hmartiro commented 1 year ago

I think this is relevant for #138 @marianna13

For testing's sake, here's how I created the same loader externally:

from audiolm_pytorch.data import SoundDataset, pad_to_longest_fn
from torch.utils.data import DataLoader

# `folder`, `data_max_length_seconds`, `batch_size`, `num_workers`, `kwargs`,
# and `soundstream` come from the surrounding training script
dataset = SoundDataset(
    folder,
    max_length=int(data_max_length_seconds * soundstream.target_sample_hz),
    target_sample_hz=soundstream.target_sample_hz,
    seq_len_multiple_of=soundstream.seq_len_multiple_of,
)

train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=pad_to_longest_fn,  # pads variable-length clips in a batch to the longest
    num_workers=num_workers,
    shuffle=True,
    **kwargs,
)
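
For completeness, here is a minimal sketch of how a loader built this way could then be handed to the trainer. The `dataloader=` keyword below is an assumption for illustration only and may not match the exact argument name this PR introduces:

from audiolm_pytorch import SoundStreamTrainer

# hypothetical: pass the pre-built DataLoader to the trainer instead of a folder path;
# `dataloader=` is an illustrative name, check the merged signature for the real one
trainer = SoundStreamTrainer(
    soundstream,                 # the same SoundStream instance referenced above
    dataloader=train_dataloader,
    num_train_steps=1_000_000,
)

trainer.train()
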
hmartiro commented 1 year ago

I might also suggest considering a broader refactor where dataloaders are always constructed externally and passed into the trainers, removing the data-related arguments from the trainers entirely. It would make the simple case require more code, but could make the data flow easier to understand.
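
To make the tradeoff concrete, here is a rough sketch of what that could look like; the keyword names are hypothetical and only illustrate the shape of the API. Even the simple case would then require building the dataset and DataLoader explicitly, as in the snippet above, before calling:

# hypothetical post-refactor call: no folder or other data-related arguments on the trainer,
# only ready-made loaders constructed by the caller (keyword names are illustrative)
trainer = SoundStreamTrainer(
    soundstream,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,  # built the same way as train_dataloader above
    num_train_steps=1_000_000,
)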

lucidrains commented 1 year ago

yes, this looks great! 🙏