lhotse-speech / lhotse

Tools for handling speech data in machine learning projects.
https://lhotse.readthedocs.io/en/latest/
Apache License 2.0

Too many time_constraint.exceeded() warnings, and training stops quite early. #1269

Closed kobenaxie closed 8 months ago

kobenaxie commented 8 months ago

I installed the latest lhotse (commit c678849) and initialized a dataloader for WebDataset like this:

# webdataset DDP example
from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import (
    DynamicBucketingSampler,
    IterableDatasetWrapper,
    K2SpeechRecognitionDataset,
    make_worker_init_fn,
)

cuts = CutSet.from_webdataset(
    ...,
    split_by_node=True,
    split_by_worker=True,
    shuffle_shards=True,
)
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200,
    shuffle=True,
    num_buckets=30,
    buffer_size=30 * 200,
    shuffle_buffer_size=30 * 500,
    drop_last=True,
    rank=0,
    world_size=1,
)
dataset = IterableDatasetWrapper(dataset=K2SpeechRecognitionDataset(...), sampler=sampler)
dloader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=num_workers,
    worker_init_fn=make_worker_init_fn(
        rank=rank,
        world_size=world_size,
    ),
)

The training process stopped quite early with many time_constraint.exceeded() warnings; it seems most of the data was skipped. https://github.com/lhotse-speech/lhotse/blob/c6788492bff2b6d431ccdef2e19041da947907fa/lhotse/dataset/sampling/dynamic.py#L325-L330

There are also some warnings from IterableDatasetWrapper https://github.com/lhotse-speech/lhotse/blob/c6788492bff2b6d431ccdef2e19041da947907fa/lhotse/dataset/iterable_dataset.py#L69-L76 which are caused by https://github.com/lhotse-speech/lhotse/blob/c6788492bff2b6d431ccdef2e19041da947907fa/lhotse/dataset/sampling/base.py#L104-L108 , which was introduced in this PR.

desh2608 commented 8 months ago

Can you show some statistics for your cuts using cuts.describe()?

pzelasko commented 8 months ago

The training process stopped quite early with many time_constraint.exceeded() warnings

I just realized that the exceeded warning can generate false positives if the last sampled cut in a mini-batch happens to be the longest cut. I don't expect this to be an issue in practice, as it would likely exceed only by a small number of seconds, proportional to the difference last_sampled_duration - longest_seen (which would be especially small if you use bucketing). I made a PR to remove the false positives -- we'll only emit the warning if the mini-batch has a single example: https://github.com/lhotse-speech/lhotse/pull/1270
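A toy model can illustrate one way such a false positive arises. This is a deliberately simplified sketch, not lhotse's actual TimeConstraint code: it uses a pessimistic check that assumes one more cut as long as the longest seen so far might still be added, which trips when the last (and longest) cut lands near the duration cap.

```python
# Simplified illustration (NOT lhotse's actual implementation) of how an
# "exceeded" check can fire spuriously when the last sampled cut in a
# mini-batch is also the longest cut seen so far.
class SimpleTimeConstraint:
    def __init__(self, max_duration):
        self.max_duration = max_duration
        self.current = 0.0       # total duration accumulated so far
        self.longest_seen = 0.0  # longest single cut seen so far

    def add(self, duration):
        self.current += duration
        self.longest_seen = max(self.longest_seen, duration)

    def exceeded(self):
        # Pessimistic: assume the next cut could be as long as the
        # longest cut seen so far.
        return self.current + self.longest_seen > self.max_duration


tc = SimpleTimeConstraint(max_duration=20.0)
for d in [6.0, 6.0, 7.0]:  # batch totals 19 s, under the 20 s cap
    tc.add(d)

# The last cut (7 s) is also the longest seen, so the pessimistic check
# reports "exceeded" even though the batch itself fits under the cap.
print(tc.exceeded())  # True, since 19 + 7 > 20
```

The overshoot here is bounded by longest_seen, which matches the observation that bucketing (where cuts in a batch have similar lengths) keeps the false-positive margin small.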

it seems most of the data were skipped.

Try increasing buffer_size; at 6000 you're on the lower edge. If your data is heavily skewed towards short utterances, it's possible you're unable to populate 30 buckets well enough to form a full max_duration batch, and drop_last=True will cause the sampler to exit early if it can't form a full batch. However:

and some warnings from IterableDatasetWrapper

This looks like it might be the real issue here; it'd cause the sampler to drop a lot of batches... I'll revert that change for now. @yuekaizhang could you elaborate on whether this is really necessary for Whisper v3 support in https://github.com/lhotse-speech/lhotse/pull/1260? Maybe we can figure out another way, as this turned out to be a breaking change.

pzelasko commented 8 months ago

I figured out a fix for the sampler in #1270 which keeps the changes in #1260 but doesn't apply them if you explicitly set rank and world_size, which will support both use cases.

kobenaxie commented 8 months ago

With the same code and lhotse==1.19.2 installed from pip (pip install lhotse), the training process is normal.

yuekaizhang commented 8 months ago

@yuekaizhang could you elaborate on whether this is really necessary for Whisper v3 support in #1260? Maybe we can figure out another way, as this turned out to be a breaking change.

Sorry for the bug. The change is intended for DeepSpeed usage: when we use DeepSpeed and call torchrun train.py or deepspeed train.py, the model is not a DDP instance; however, the torchrun launcher sets the rank and world_size automatically.
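The launcher behavior described here can be sketched as follows. torchrun exports RANK and WORLD_SIZE as environment variables for every worker process, so code can read them directly instead of inferring them from a DDP-wrapped model (a minimal sketch, not lhotse's exact code; the helper name is made up for illustration):

```python
import os


def get_rank_and_world_size():
    # torchrun (and the deepspeed launcher) set RANK and WORLD_SIZE in
    # each worker's environment; fall back to a single-process setup
    # when they are absent.
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, world_size


# Simulate what the torchrun launcher would export for worker 2 of 8.
os.environ["RANK"] = "2"
os.environ["WORLD_SIZE"] = "8"
print(get_rank_and_world_size())  # (2, 8)
```

This is why a sampler that auto-detects the environment can conflict with one where rank/world_size were passed explicitly, which is exactly the case #1270 distinguishes.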

kobenaxie commented 8 months ago

Can you show some statistics for your cuts using cuts.describe()?

Cut statistics:
╒═══════════════════════════╤════════════╕
│ Cuts count:               │ 1009223    │
├───────────────────────────┼────────────┤
│ Total duration (hh:mm:ss) │ 1000:45:42 │
├───────────────────────────┼────────────┤
│ mean                      │ 3.6        │
├───────────────────────────┼────────────┤
│ std                       │ 1.5        │
├───────────────────────────┼────────────┤
│ min                       │ 0.0        │
├───────────────────────────┼────────────┤
│ 25%                       │ 2.4        │
├───────────────────────────┼────────────┤
│ 50%                       │ 3.3        │
├───────────────────────────┼────────────┤
│ 75%                       │ 4.4        │
├───────────────────────────┼────────────┤
│ 99%                       │ 8.0        │
├───────────────────────────┼────────────┤
│ 99.5%                     │ 8.6        │
├───────────────────────────┼────────────┤
│ 99.9%                     │ 10.2       │
├───────────────────────────┼────────────┤
│ max                       │ 19.3       │
├───────────────────────────┼────────────┤
│ Recordings available:     │ 1009223    │
├───────────────────────────┼────────────┤
│ Features available:       │ 0          │
├───────────────────────────┼────────────┤
│ Supervisions available:   │ 1009223    │
╘═══════════════════════════╧════════════╛
Speech duration statistics:
╒══════════════════════════════╤════════════╤══════════════════════╕
│ Total speech duration        │ 1000:45:42 │ 100.00% of recording │
├──────────────────────────────┼────────────┼──────────────────────┤
│ Total speaking time duration │ 1000:45:42 │ 100.00% of recording │
├──────────────────────────────┼────────────┼──────────────────────┤
│ Total silence duration       │ 00:00:00   │ 0.00% of recording   │
╘══════════════════════════════╧════════════╧══════════════════════╛

It is the AISHELL-2 dataset, so I think the problem has nothing to do with the data.
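These statistics also let us sanity-check the earlier buffer_size advice with a back-of-envelope calculation (a rough sketch assuming the buffer fills evenly across buckets, which real bucketing does not guarantee):

```python
# Rough capacity check using the sampler settings from the issue and the
# mean cut duration from cuts.describe() above.
buffer_size = 30 * 200  # 6000 cuts, as configured in the sampler
num_buckets = 30
mean_duration = 3.6     # seconds per cut (from the statistics table)
max_duration = 200      # seconds of audio per batch

cuts_per_bucket = buffer_size / num_buckets            # ~200 cuts
seconds_per_bucket = cuts_per_bucket * mean_duration   # ~720 s of audio
batches_per_bucket = seconds_per_bucket / max_duration # ~3.6 batches

print(cuts_per_bucket, seconds_per_bucket, batches_per_bucket)
```

With only a few full batches' worth of audio per bucket, an unlucky skew in cut lengths can leave some buckets unable to fill a max_duration batch, which is the failure mode the buffer_size advice targets.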

kobenaxie commented 8 months ago

The latest lhotse (1.20.0.dev0+git.9f4bfa1.clean) works normally now.