facebookresearch / fairseq2

FAIR Sequence Modeling Toolkit 2
https://facebookresearch.github.io/fairseq2/
MIT License

Audio Data Loader Example? #120

Open cinjon opened 1 year ago

cinjon commented 1 year ago

Hi, I'd like to use FairSeq2 for my audio models. Is there an example of an audio data loader in the works? That would be super helpful.

What I need in the dataloader:

  1. I have lots of audio that I'd like to load. This is currently in mp3, but I've been cutting chunks out as tuples of (audio: Wav, labels: Json).
  2. This data is hosted on R2 and I'd like the loader to stream it via an endpoint_url.
  3. I also would like for this to save to local cache as it's a big dataset.
  4. Distribute the work amongst CPU workers.

I can accommodate other data-saving formats as needed, e.g. a flat zip, but the most important thing is that it's fast. Right now I'm I/O-bottlenecked.
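For requirements 2 and 3 above (streaming from R2 via an `endpoint_url`, with a local on-disk cache), one generic option outside fairseq2 is fsspec's `simplecache` protocol chained in front of `s3://` (R2 is S3-compatible). This is just a sketch; the bucket, key, and endpoint names are hypothetical, and it needs the `fsspec` and `s3fs` packages:

```python
def cached_r2_url(bucket: str, key: str) -> str:
    # "simplecache::" chains a local disk cache in front of the remote read,
    # so each object is fetched from R2 once and re-read from disk afterwards.
    return f"simplecache::s3://{bucket}/{key}"

def open_cached(bucket: str, key: str, endpoint_url: str,
                cache_dir: str = "/tmp/r2_cache"):
    import fsspec  # needs the s3fs package for the s3:// protocol
    # Per-protocol kwargs are passed as dicts keyed by protocol name.
    return fsspec.open(
        cached_r2_url(bucket, key),
        mode="rb",
        s3={"client_kwargs": {"endpoint_url": endpoint_url}},
        simplecache={"cache_storage": cache_dir},
    )

# Hypothetical usage:
# f = open_cached("my-audio", "shards/00000.bin",
#                 "https://<account-id>.r2.cloudflarestorage.com")
```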

Thanks!

cbalioglu commented 1 year ago

Hey @cinjon, is there any chance for me to check out an example dataset and the script you use to load it?

cinjon commented 1 year ago

Sure, I rewrote it last night to use MosaicML Streaming. This is what it looks like at the moment. I'm using an ordinary PyTorch DataLoader with num_workers=11 and pin_memory=True on a 12-core machine and getting ~10 s to build a batch of 256.
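As a quick back-of-the-envelope check on those numbers (assuming the 11 workers scale roughly linearly):

```python
batch_size = 256
batch_seconds = 10.0   # observed wall time to build one batch
num_workers = 11

wall_per_sample = batch_seconds / batch_size            # wall time per sample
worker_time_per_sample = wall_per_sample * num_workers  # time each worker spends

decode_seconds = 0.01  # the torchaudio.load cost noted in the snippet below
decode_fraction = decode_seconds / worker_time_per_sample
print(f"{worker_time_per_sample:.3f}s per sample, decode is {decode_fraction:.0%}")
# → 0.430s per sample, decode is 2%
```

So the ~10 ms decode accounts for only about 2% of each worker's per-sample budget; the remaining time is presumably spent fetching and deserializing shards, which is consistent with being I/O-bound.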

### Stream Reader:
import io
import random
from typing import Any

import streaming
import torch
import torchaudio

# window_in_secs, predict_secs and target_sr are module-level constants
# (values elided here).

class MyStreamingDataset(streaming.StreamingDataset):
  def __init__(self, local, remote, shuffle):
    super().__init__(local=local, remote=remote, shuffle=shuffle)

  def __getitem__(self, idx: int) -> Any:
    # columns = {
    #     'start_time': 'float32',
    #     'key': 'str',
    #     'end_time': 'float32',
    #     'label': 'int8',
    #     'wav': 'bytes',
    # }
    obj = super().__getitem__(idx)

    end_time = obj['end_time']
    start_time = obj['start_time']
    label = obj['label']
    wav = io.BytesIO(obj['wav'])

    # window_in_secs = 5, so there's buffer in the loaded 6-second example.
    relative_start_time = end_time - window_in_secs - start_time
    if label:
      # Positive sample: can only use a small part of the sample.
      max_reduction = min(relative_start_time, predict_secs)
      this_start_time = relative_start_time - max_reduction * random.random()
      offset = int(target_sr * this_start_time)
      label = torch.tensor(1, dtype=torch.int64)
    else:
      # Negative sample: the entire sample is fair game.
      this_start_time = random.random() * relative_start_time
      offset = int(target_sr * this_start_time)
      label = torch.tensor(0, dtype=torch.int64)

    num_frames = window_in_secs * target_sr
    # NOTE: This loading step takes .01 seconds by itself :(
    wav, sr = torchaudio.load(wav, frame_offset=offset, num_frames=num_frames)
    wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
    return wav, label
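The random-offset branch in `__getitem__` can be sanity-checked in isolation. A sketch with assumed values for the module-level constants (`window_in_secs = 5` per the comment; `predict_secs` and `target_sr` are hypothetical):

```python
import random

window_in_secs = 5   # per the comment in __getitem__
predict_secs = 1     # hypothetical
target_sr = 16000    # hypothetical

def sample_offset(start_time, end_time, label, rng):
    # Mirrors the branch logic above: positives may only shift back by up to
    # predict_secs; negatives may start anywhere in the slack.
    relative_start_time = end_time - window_in_secs - start_time
    if label:
        max_reduction = min(relative_start_time, predict_secs)
        this_start_time = relative_start_time - max_reduction * rng.random()
    else:
        this_start_time = rng.random() * relative_start_time
    return int(target_sr * this_start_time)

rng = random.Random(0)
# A 6-second clip (start_time=0, end_time=6) leaves 1 s of slack before the
# 5 s window, so the frame offset always falls within the first second.
for label in (0, 1):
    for _ in range(100):
        assert 0 <= sample_offset(0.0, 6.0, label, rng) <= target_sr
```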

### Stream Writer:
import random

import scipy.io.wavfile
from streaming import MDSWriter

# `snippets`, `wav`, `sr`, `key`, `examples`, `basename_wav`, `columns`,
# `compression` and `mds_directory` come from the elided parts.
def get_examples(...):
    ...
    window = 6  # seconds
    window_in_samples = window * sr
    for num_snippet, snippet in enumerate(snippets):
        t = snippet["time"]
        if t < window:
            continue

        start_time = t - window
        end_time = t
        start_frames = int(start_time * sr)
        end_frames = start_frames + window_in_samples

        sub_wav = wav[start_frames:end_frames]
        assert sub_wav.shape[0] == window_in_samples
        scipy.io.wavfile.write(basename_wav, sr, sub_wav)
        labels = {
            'end_time': t,
            'start_time': start_time,
            'key': key,
            'label': 1,
        }
        examples.append((sub_wav, labels))

all_examples = get_examples()
random.shuffle(all_examples)
with MDSWriter(out=mds_directory, columns=columns, compression=compression) as out:
    for wav, labels in all_examples:
        sample = labels.copy()
        sample['wav'] = wav
        out.write(sample)
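The writer's slicing arithmetic can also be checked in isolation. A small sketch with a hypothetical 16 kHz sample rate (the actual `sr` is elided above):

```python
def window_slice_bounds(t: float, sr: int, window: int = 6):
    """(start_frame, end_frame) for the window-second clip ending at time t."""
    start_time = t - window
    start_frames = int(start_time * sr)
    end_frames = start_frames + window * sr
    return start_frames, end_frames

# e.g. a snippet at t=12.5 s with an assumed 16 kHz rate:
start, end = window_slice_bounds(12.5, sr=16000)
assert end - start == 6 * 16000  # always exactly window_in_samples frames
```

Computing `end_frames` from `start_frames` (rather than from `end_time` separately) is what guarantees the writer's `assert sub_wav.shape[0] == window_in_samples` holds despite `int()` truncation.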