facebookresearch / fairseq2

FAIR Sequence Modeling Toolkit 2
https://facebookresearch.github.io/fairseq2/
MIT License

Audio Data Loader Example? #120

Open cinjon opened 11 months ago

cinjon commented 11 months ago

Hi, I'd like to use FairSeq2 for my audio models. Is there an example of an audio data loader in the works? That would be super helpful.

What I need in the dataloader:

  1. I have lots of audio that I'd like to load. This is currently in mp3, but I've been cutting chunks out as tuples of (audio: Wav, labels: Json).
  2. This data is hosted on R2 and I'd like the loader to stream it via an endpoint_url.
  3. I also would like for this to save to local cache as it's a big dataset.
  4. Distribute amongst cpu workers.

I can accommodate other data-saving formats as needed, e.g. a flat zip, but the most important thing is speed; I'm currently I/O-bottlenecked.
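For points 2 and 3, one option is MosaicML's streaming library, which caches shards locally and talks to S3-compatible stores such as R2 through an endpoint override. A minimal sketch (the account ID and paths are placeholders, and the `S3_ENDPOINT_URL` environment variable is how the library is documented to pick up a custom endpoint):

```python
import os

# R2 speaks the S3 API, so point the S3 client at the R2 endpoint.
# Placeholder account ID; streaming reads this env var for S3-compatible stores.
os.environ["S3_ENDPOINT_URL"] = "https://<account-id>.r2.cloudflarestorage.com"

# from streaming import StreamingDataset
# dataset = StreamingDataset(
#     remote="s3://my-bucket/mds",  # shards streamed from R2
#     local="/tmp/mds-cache",       # on-disk cache for repeated reads
#     shuffle=True,
# )
```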

Thanks!

cbalioglu commented 11 months ago

Hey @cinjon, is there any chance for me to check out an example dataset and the script you use to load it?

cinjon commented 11 months ago

Sure, I rewrote it last night to use the Mosaic Streamer. This is what it looks like at the moment. I am using an ordinary PyTorch DataLoader with num_workers=11 and pin_memory=True on a 12-core machine and getting ~10 s to build a batch of 256.
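For scale: at the ~0.01 s per-example decode cost noted in the code below, a batch of 256 spread over 11 workers should take roughly a quarter of a second of decode time, so a 10 s batch time points at I/O or shard-fetch overhead rather than decoding. A quick back-of-envelope check:

```python
decode_s = 0.01  # measured torchaudio.load cost per example
batch = 256
workers = 11

# Ideal decode wall time if workers split the batch evenly.
ideal = decode_s * batch / workers
print(round(ideal, 2))  # → 0.23
```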

### Stream Reader:
import io
import random
from typing import Any

import streaming
import torch
import torchaudio

# Module-level config used below (values elided in the original):
# window_in_secs, predict_secs, target_sr.

class MyStreamingDataset(streaming.StreamingDataset):
  def __init__(self, local, remote, shuffle):
    super().__init__(local=local, remote=remote, shuffle=shuffle)

  def __getitem__(self, idx: int) -> Any:
    # columns = {
    #     'start_time': 'float32',
    #     'key': 'str',
    #     'end_time': 'float32',
    #     'label': 'int8',
    #     'wav': 'bytes',
    # }
    obj = super().__getitem__(idx)

    end_time = obj['end_time']
    start_time = obj['start_time']
    label = obj['label']
    wav = io.BytesIO(obj['wav'])

    # window_in_secs = 5, so there's buffer in the loaded 6-second example.
    relative_start_time = end_time - window_in_secs - start_time
    if label:
      # Positive sample: can only use a small part of the clip.
      max_reduction = min(relative_start_time, predict_secs)
      this_start_time = relative_start_time - max_reduction * random.random()
      offset = int(target_sr * this_start_time)
      label = torch.tensor(1, dtype=torch.int64)
    else:
      # Negative sample: the entire clip is fair game.
      this_start_time = random.random() * relative_start_time
      offset = int(target_sr * this_start_time)
      label = torch.tensor(0, dtype=torch.int64)

    num_frames = window_in_secs * target_sr
    # NOTE: This loading step takes .01 seconds by itself :(
    wav, sr = torchaudio.load(wav, frame_offset=offset, num_frames=num_frames)
    wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
    return wav, label
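The offset arithmetic in `__getitem__` can be factored into a pure helper, which makes it unit-testable without touching the dataset. A sketch (the function name, the `rng` hook, and the default values for `predict_secs` and `target_sr` are assumptions for illustration):

```python
import random

def sample_offset(start_time, end_time, label, window_in_secs=5,
                  predict_secs=1.0, target_sr=16000, rng=random.random):
    """Pick a frame offset inside the stored 6 s clip for a 5 s window.

    relative_start_time is the slack the stored clip has beyond the
    window; positives are constrained to the last predict_secs of it.
    (Defaults here are hypothetical, not from the original code.)
    """
    relative_start_time = end_time - window_in_secs - start_time
    if label:
        max_reduction = min(relative_start_time, predict_secs)
        this_start_time = relative_start_time - max_reduction * rng()
    else:
        this_start_time = relative_start_time * rng()
    return int(target_sr * this_start_time)
```

With a fixed `rng`, the two branches are easy to check: a positive with `rng=lambda: 0.0` lands at the earliest allowed offset, and a negative with `rng=lambda: 0.5` lands halfway through the slack.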

### Stream Writer:
def get_examples(...):
    ...
    examples = []
    window = 6  # seconds
    window_in_samples = window * sr
    for num_snippet, snippet in enumerate(snippets):
        t = snippet["time"]
        if t < window:
            continue

        start_time = t - window
        end_time = t
        start_frames = int(start_time * sr)
        end_frames = start_frames + window_in_samples

        sub_wav = wav[start_frames:end_frames]
        assert sub_wav.shape[0] == window_in_samples
        # Encode as WAV bytes so the reader can do io.BytesIO(obj['wav']).
        buf = io.BytesIO()
        scipy.io.wavfile.write(buf, sr, sub_wav)
        labels = {
            'end_time': end_time,
            'start_time': start_time,
            'key': key,
            'label': 1,
        }
        examples.append((buf.getvalue(), labels))
    return examples

all_examples = get_examples()
random.shuffle(all_examples)
with MDSWriter(out=mds_directory, columns=columns, compression=compression) as out:
    for wav, labels in all_examples:
        sample = labels.copy()
        sample['wav'] = wav
        out.write(sample)
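Since the reader does `io.BytesIO(obj['wav'])`, the `'wav': 'bytes'` column needs encoded WAV file bytes, not a raw array. The stdlib `wave` module can do the in-memory encoding without scipy; a sketch (16-bit mono PCM assumed, and the function name is hypothetical):

```python
import io
import wave

def wav_to_bytes(pcm: bytes, sr: int, sampwidth: int = 2) -> bytes:
    """Wrap raw PCM samples in a WAV container, entirely in memory."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)          # mono, matching the reader's downmix
        w.setsampwidth(sampwidth)  # 2 bytes = int16
        w.setframerate(sr)
        w.writeframes(pcm)
    return buf.getvalue()
```

The returned bytes round-trip through `wave.open(io.BytesIO(...))`, which is the same shape of object `torchaudio.load` receives on the read side.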