lhotse-speech / lhotse

Tools for handling speech data in machine learning projects.
https://lhotse.readthedocs.io/en/latest/
Apache License 2.0

Tutorial review: Using Lhotse with PyTorch Lightning #1086

Open fauxneticien opened 1 year ago

fauxneticien commented 1 year ago

Hi all,

Following up on the tutorials thread here, I've written a first draft of a tutorial for using Lhotse with PyTorch Lightning here: https://colab.research.google.com/drive/13VzQYEUJGcbQEU6L48jsNU3r46ASwGwp?usp=sharing

The notebook can also be found in the notebooks folder in https://github.com/fauxneticien/lnl-examples, which has the associated file train_mwe.py . You can see from this W&B run that I've confirmed the code runs on single/multi-GPU, with and without AMP.

This is just a start with a very basic minimal working example to help newcomers get oriented. I'm planning on extending the MWE with at least 2 more tutorials demoing more realistic usage (e.g. global CMVN on features, SpecAugment, etc.). Before that I was hoping to get feedback/thoughts on the first tutorial and MWE. Any notes on explanations of Lhotse's terms would be very welcome.

I also have one question about the Lhotse API. It seems that BucketingSampler doesn't implement len(), although from looking around the repo a bit, it might have at some point in the past. I ask because the default Lightning progress bar isn't able to fetch the number of iterations per epoch. Lightning does let you configure the progress bar at each epoch through a callback (see snippet below). Is calling len([s for s in sampler]) the way to get this number? I don't know the internals well enough to foresee whether looping through the sampler might put the sampler and loader out of sync (and whether DDP might complain).

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar

class LhotseCompatibleProgressBar(TQDMProgressBar):
    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.total = len([s for s in self.trainer.train_dataloader.sampler])
        return bar

bar = LhotseCompatibleProgressBar()
trainer = Trainer(callbacks=[bar])

Thanks! Nay

pzelasko commented 1 year ago

Thanks, it looks good! As a minor technical note, I don't think you need any of <bos>/<eos>/<pad> (and actually even <unk>) for CTC models.

Regarding len(sampler), we did support it early on, but it was a bad design decision that we reverted. Because of the dynamic batch size, computing len requires iterating over the full dataset, which is undesirable for larger datasets. The sampler would also have a different len depending on the random seed / epoch. You can instead use sampler.remaining_duration to approximate the progress (assuming the CutSet is not lazy; if it is, it'd just return None).
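For example, here's a rough sketch of approximating in-epoch progress from remaining_duration (assuming the eager cuts_train, train_sampler and train_loader names from the notebook's MWE):

# Known up front for an eager CutSet; for a lazy one this would itself need a full pass.
total_duration = sum(cut.duration for cut in cuts_train)

for batch_idx, batch in enumerate(train_loader):
    remaining = train_sampler.remaining_duration  # None when the CutSet is lazy
    if remaining is not None:
        done_pct = 100 * (1 - remaining / total_duration)
        print(f"batch {batch_idx}: roughly {done_pct:.0f}% of the epoch sampled")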

mrezasoltani commented 1 year ago

I am using the above solution along with DynamicBucketingSampler(), but I still get an error (PyTorch Lightning complains about the missing __len__ method):

class LhotseCompatibleProgressBar(TQDMProgressBar):

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.total = len([s for s in self.trainer.train_dataloader.sampler])
        return bar

    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.total = len([s for s in self.trainer.test_dataloader.sampler])
        return bar

The example in the notebook is based on BucketingSampler, which does not support working with a lazy CutSet.

Any solution?

Thanks

fauxneticien commented 1 year ago

Hi @mrezasoltani —

Since it's not straightforward to get the length of a BucketingSampler, I suggest using a global progress bar based on MAX_STEPS.

import lightning.pytorch as pl

from lightning.pytorch.callbacks import ProgressBar
from tqdm import tqdm

MAX_STEPS = 10_000

class GlobalProgressBar(ProgressBar):
    def __init__(self):
        super().__init__()  # don't forget this :)
        self.enable = True
        self.pbar = tqdm(total=MAX_STEPS)

    def disable(self):
        self.enable = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)
        self.pbar.update(1)
        self.pbar.set_description_str(f"Epoch: {trainer.current_epoch + 1}")
        self.pbar.set_postfix_str(f"loss={outputs['loss']:.2f}")

    def on_train_end(self, trainer, pl_module):
        self.pbar.close()

trainer = pl.Trainer(
    max_steps=MAX_STEPS,
    callbacks=[ GlobalProgressBar() ]
)
fauxneticien commented 1 year ago
  1. TokenCollater

As a minor technical note, I don't think you need any of <bos>/<eos>/<pad> (and actually even <unk>) for CTC models.

Ah, I see — thanks @pzelasko ! I think <pad> is still needed for the CTCLoss, right? At least that's my interpretation of the first argument to CTCLoss, unless <pad> and the CTC blank are functionally different.

Would you be open to me submitting a pull request which adds an add_unk argument to TokenCollater? See here for a proposal: https://gist.github.com/fauxneticien/9976752d7c11619c720e99d6ef8e1d7a/revisions

  2. Global MVN

Additionally, for the next tutorial, I'm trying to figure out how to use Global MVN with OnTheFlyFeatures (i.e. not computed and stored). Here's my initial attempt:

import torch

from torch.utils.data import DataLoader

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import BucketingSampler, OnTheFlyFeatures
from lhotse.dataset.collation import TokenCollater
from lhotse.recipes import download_librispeech, prepare_librispeech

class MinimalASRDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, global_stats):
        self.extractor = OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
        self.tokenizer = tokenizer
        self.global_stats = { stat:torch.tensor(values) for (stat,values) in global_stats.items() }

    def __getitem__(self, cuts: CutSet) -> dict:
        cuts = cuts.sort_by_duration()

        feats, feat_lens = self.extractor(cuts)
        feats = (feats - self.global_stats["norm_means"]) / self.global_stats['norm_stds']

        tokens, token_lens = self.tokenizer(cuts)
        return {"inputs_padded": feats, "input_lengths": feat_lens, "labels_padded": tokens, "label_lengths": token_lens}

libri = prepare_librispeech(corpus_dir="LibriSpeech", output_dir="data/")

cuts_train = CutSet.from_manifests(**libri["train-clean-5"])

tokenizer = TokenCollater(cuts_train)
global_stats = cuts_train.compute_global_feature_stats(max_cuts=1000, extractor=Fbank(FbankConfig(num_mel_bins=80)))

train_sampler = BucketingSampler(cuts_train, max_duration=300, shuffle=True, drop_last=True)

train_loader = DataLoader(
    MinimalASRDataset(tokenizer, global_stats),
    sampler=train_sampler,
    batch_size=None,
    num_workers=1
)

# Get a batch
batch = next(iter(train_loader))

I'm sure there's a better way. I see that K2SpeechRecognitionDataset lets you pass in a list of input transforms, but I'm unsure how a pre-computed global MVN is meant to be applied to OnTheFlyFeatures.

mrezasoltani commented 1 year ago

Thanks, but I think this doesn't work either. Even with this global pbar, it complains about the length. Again, this happens with DynamicBucketingSampler(), which is needed for having a lazy CutSet.

fauxneticien commented 1 year ago

Hi @mrezasoltani —

Here's some dummy code of my 'hack' around it, taking advantage of two facts: max_steps gives the training progress bar a known total, and setting num_sanity_val_steps=-1 runs the whole validation set once up front so the number of validation batches can be counted.

A bit hacky I guess, but any suggestions are welcome...

import time
import torch
import lightning.pytorch as pl

from lightning.pytorch.callbacks import ProgressBar

from lhotse import CutSet
from lhotse.recipes import download_librispeech, prepare_librispeech
from lhotse.dataset.sampling import DynamicBucketingSampler

from torch.utils.data import DataLoader
from tqdm import tqdm

download_librispeech(dataset_parts="mini_librispeech")
libri = prepare_librispeech(corpus_dir="LibriSpeech", output_dir="data/")

cuts_train = CutSet.from_manifests(**libri["train-clean-5"])
cuts_valid = CutSet.from_manifests(**libri["dev-clean-2"])

train_sampler = DynamicBucketingSampler(cuts_train, max_duration=100, shuffle=True, drop_last=True)
valid_sampler = DynamicBucketingSampler(cuts_valid,  max_duration=100, shuffle=False, drop_last=True)

MAX_STEPS=1000

class MinimalASRDataset(torch.utils.data.Dataset):
    def __getitem__(self, cuts: CutSet) -> dict:
        cuts = cuts.sort_by_duration()
        return cuts

train_loader = DataLoader(
    MinimalASRDataset(),
    sampler=train_sampler,
    batch_size=None,
    num_workers=1
)

valid_loader = DataLoader(
    MinimalASRDataset(),
    sampler=valid_sampler,
    batch_size=None,
    num_workers=1
)

class DummyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def configure_optimizers(self):
        # required by Trainer, but not relevant for this test
        return torch.optim.Adam(self.parameters())

    def training_step(self, batch, batch_idx):
        return None

    def validation_step(self, batch, batch_idx):
        # Slow down loop to see progress bar in action
        time.sleep(0.01)
        return None    

class GlobalProgressBar(ProgressBar):
    def __init__(self):
        super().__init__()  # don't forget this :)
        self.enable = True
        self.train_pbar = tqdm(total=MAX_STEPS)

        self.sanity_check_done = False
        self.sanity_check_steps = 0

    def disable(self):
        self.enable = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)
        self.train_pbar.update(1)
        self.train_pbar.set_description_str(f"Epoch: {trainer.current_epoch + 1}")

    def on_train_end(self, trainer, pl_module):
        self.train_pbar.close()

    def on_sanity_check_end(self, trainer, pl_module):
        self.sanity_check_done = True
        self.valid_pbar.close()

    def on_validation_start(self, trainer, pl_module):
        if not self.sanity_check_done:
            self.valid_pbar = tqdm(desc="Running full epoch to estimate number of validation batches... ")
        else:
            self.valid_pbar = tqdm(desc="Running validation", total=self.sanity_check_steps)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)

        if not self.sanity_check_done:
            self.sanity_check_steps += 1
        else:
            self.valid_pbar.update(1)

    def on_validation_end(self, trainer, pl_module):
        self.valid_pbar.close()

trainer = pl.Trainer(
    max_steps=MAX_STEPS,
    accelerator="cpu", 
    devices=1,
    enable_checkpointing=False,
    enable_model_summary=False,
    logger=False,
    strategy="ddp",
    # Prevent Lightning from replacing Lhotse's DDP-compatible sampler
    use_distributed_sampler=False,
    callbacks=[ GlobalProgressBar() ],
    # Set this to -1 to run entire epoch through validation set
    num_sanity_val_steps=-1
)

trainer.fit(
    DummyModel(),
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader
)
pzelasko commented 1 year ago

Regarding the progress bar issues, I'm sure that such a comprehensive framework as PT Lightning has some way to support datasets that have unknown length and wouldn't force you to hack the progress bar. Unfortunately, I don't know PT Lightning very well so I can't suggest a solution.

Ah, I see — thanks @pzelasko ! I think <pad> is still needed for the CTCLoss, right? At least that's my interpretation of the first argument to CTCLoss, unless <pad> and the CTC blank are functionally different.

You shouldn't need <pad> for CTCLoss, I think it's meant mostly for TTS models where text is the input and possibly also for attention decoder ASR.
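To illustrate the distinction with a minimal sketch (the vocabulary and tensor shapes here are made up): CTCLoss takes the blank index directly via its blank argument, so the label vocabulary doesn't need a separate <pad> entry.

import torch

vocab = ["<blk>", "a", "b", "c"]          # illustrative vocabulary; "<blk>" is CTC's blank
ctc_loss = torch.nn.CTCLoss(blank=0)      # blank index passed explicitly

log_probs = torch.randn(50, 4, len(vocab)).log_softmax(dim=-1)   # (T, N, C)
targets = torch.randint(1, len(vocab), (4, 10))                  # targets never contain the blank
input_lengths = torch.full((4,), 50, dtype=torch.long)
target_lengths = torch.full((4,), 10, dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)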

Would you be open to me submitting a pull request which adds an add_unk argument to TokenCollator? See here for a proposal: https://gist.github.com/fauxneticien/9976752d7c11619c720e99d6ef8e1d7a/revisions

You're welcome to submit a PR.

  2. Global MVN

Additionally, for the next tutorial, I'm trying to figure out how to use Global MVN with OnTheFlyFeatures (i.e. not computed and stored).

You would need to estimate MVN stats before starting training, e.g. compute features for ~1000 cuts, save MVN stats, and load them in the training script. But in general you shouldn't need this kind of MVN for most of today's models -- you can either multiply the features by some small number (~0.1) to compress the range, or if you don't care about streaming, apply sth like layer norm (or do nothing -- many models would still work OK).
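A minimal sketch of that precompute-then-load flow (the manifest path and output filename are placeholders; the stats keys match the compute_global_feature_stats output used in the snippet above):

import torch
from lhotse import CutSet, Fbank, FbankConfig

# One-off script: estimate global MVN stats on ~1000 cuts and save them.
cuts_train = CutSet.from_file("data/cuts_train.jsonl.gz")  # placeholder path
extractor = Fbank(FbankConfig(num_mel_bins=80))
stats = cuts_train.compute_global_feature_stats(max_cuts=1000, extractor=extractor)
torch.save({k: torch.as_tensor(v) for k, v in stats.items()}, "global_mvn_stats.pt")

# Later, in the training script: load the stats once and normalize the
# on-the-fly features, e.g. inside the dataset's __getitem__:
#   stats = torch.load("global_mvn_stats.pt")
#   feats = (feats - stats["norm_means"]) / stats["norm_stds"]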

fauxneticien commented 1 year ago

Progress bar

I searched IterableDataset inside the PyTorch Lightning repo. It looks like if you switch to using a RichProgressBar (instead of the default tqdm) they have some 'support' for "dataloaders that do not define a size (infinite size)" (see comment here). I say 'support' as, understandably, it does not try to infer the size at all, just displays something as opposed to nothing with tqdm:

[asciicast demo of the RichProgressBar]

Usage (for those interested)

# If you don't have rich installed: pip install rich
import lightning.pytorch as pl
from lightning.pytorch.callbacks import RichProgressBar

trainer = pl.Trainer(callbacks=[RichProgressBar()])

Since it's already written, I'll also submit my hacky version to the Lightning repo to see if anyone has better suggestions. (Edit: thread here if anyone else wants to follow)

TokenCollater with add_unk (and maybe add_pad)

You shouldn't need <pad> for CTCLoss, I think it's meant mostly for TTS models where text is the input and possibly also for attention decoder ASR.

Hm, I see. I'll hold off on the pull request and double check this. If <pad> isn't needed, I'll submit a version that also adds an add_pad flag.

Global MVN

But in general you shouldn't need this kind of MVN for most of today's models -- you can either multiply the features by some small number (~0.1) to compress the range, or if you don't care about streaming, apply sth like layer norm (or do nothing -- many models would still work OK).

Oh I see — good to know, thanks! I'll just stick to SpecAugment to demo feature transformations then (will keep this issue open while writing the 2nd tutorial).
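For the record, here's a rough sketch of where SpecAugment could slot into the MWE (assuming Lhotse's SpecAugment transform with its default arguments; the random feats tensor just stands in for the (batch, time, num_mel_bins) output of OnTheFlyFeatures):

import torch
from lhotse.dataset import SpecAugment

# Assumed sketch: apply SpecAugment to the padded feature batch, i.e. right
# after `feats, feat_lens = self.extractor(cuts)` in the dataset's __getitem__.
spec_augment = SpecAugment()       # default time/frequency masking settings

feats = torch.randn(4, 500, 80)    # dummy (batch, time, num_mel_bins) batch
feats = spec_augment(feats)        # returns the features with masks applied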

mrezasoltani commented 1 year ago

Thanks, I have not tried RichProgressBar with DynamicBucketingSampler() yet. I'll give it a shot.

I'll start doing things from scratch.