lhotse-speech / lhotse

Tools for handling speech data in machine learning projects.
https://lhotse.readthedocs.io/en/latest/
Apache License 2.0

Got warnings when loading sampler's state_dict #631

Open csukuangfj opened 2 years ago

csukuangfj commented 2 years ago

I get the following warning while trying to use https://github.com/k2-fsa/icefall/pull/259 to restore the state dict of a sampler from a checkpoint.

lhotse/dataset/sampling/simple.py:144: UserWarning: SimpleCutSampler.load_state_dict():
 Inconsistent time_constraint:
expected TimeConstraint(max_duration=10, max_samples=None, max_frames=None, current=0, num_cuts=0)
received TimeConstraint(max_duration=10, max_samples=None, max_frames=None, current=32.968312499999996, num_cuts=2)

Related code is listed below: https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L300

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L427

    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        sampler=sampler,
        rank=rank,
    )

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L800

    if checkpoints and "sampler" in checkpoints:
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
csukuangfj commented 2 years ago

Also, I find that the sampler's state_dict contains only the epoch at which the checkpoint was saved. It does not say at which batch within that epoch the checkpoint was saved.

See https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/base.py#L132-L139

https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/simple.py#L112-L119

https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/bucketing.py#L217-L230


As a result, I have to use https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L618

    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx

to skip the specified number of batches when resuming training from a checkpoint, which may take several minutes.
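
For completeness, a rough sketch of the save side of this workaround (not shown in the linked files; exactly how cur_batch_idx gets written into params before checkpointing is an assumption here):

    # Sketch only: remember where we are in the current epoch so a resumed run
    # can skip the already-consumed batches. `params`, `save_checkpoint_impl`
    # and its arguments stand in for the icefall objects.
    params["cur_batch_idx"] = batch_idx
    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        sampler=sampler,
        rank=rank,
    )
    del params["cur_batch_idx"]  # avoid leaking it into later epochs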

pzelasko commented 2 years ago

You're right that it doesn't store batch_idx; instead it stores the number of cuts that were already processed (it's inside diagnostics). It's sufficient for the sampler to correctly restore its state, and you don't need your workaround that requires long waiting times.
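
For illustration, a minimal sketch of relying on the sampler state alone; the state_dict() / load_state_dict() calls are the ones discussed in this issue, everything else (file name, surrounding code) is hypothetical:

    import torch

    # When writing a checkpoint, store the sampler state alongside everything else:
    torch.save({"sampler": train_sampler.state_dict()}, "checkpoint.pt")

    # When resuming in a new process, restore it before iterating the dataloader;
    # the sampler then fast-forwards internally, so no manual batch skipping is needed:
    state = torch.load("checkpoint.pt")
    train_sampler.load_state_dict(state["sampler"])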

pzelasko commented 2 years ago

As to the warning, I'll take another look at it later; it probably shouldn't compare the fields current and num_cuts.
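
One possible shape of such a change, purely as an illustration (it assumes TimeConstraint is a dataclass, as its repr above suggests; this is not the actual lhotse fix):

    from dataclasses import replace

    def same_constraint_config(expected, received) -> bool:
        # Compare only the configuration fields; zero out the progress counters
        # (current, num_cuts), which legitimately differ when a checkpoint is
        # taken mid-epoch.
        return replace(expected, current=0, num_cuts=0) == replace(
            received, current=0, num_cuts=0
        )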

pzelasko commented 2 years ago

Actually I was wrong: both the number of cuts and the number of batches that were consumed are kept in the diagnostics, so you can read out everything. However, the diagnostics are reset after every epoch; I think it makes sense to change that so they keep accumulating. WDYT?

Also check this PR which addresses some of your other comments https://github.com/lhotse-speech/lhotse/pull/632

csukuangfj commented 2 years ago

It's sufficient for the sampler to correctly restore its state, and you don't need your workaround that requires long waiting times.

The issue is that the batch_idx in the following line starts from 0 even if we resume training from a checkpoint. That may cause confusion for users, as it seems that training does not pick up from the position where it was previously saved.

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L620

    for batch_idx, batch in enumerate(train_dl):
pzelasko commented 2 years ago

you’d need to do sth like:

for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)

You will encounter two issues with that: 1) the diagnostics are reset after each epoch, but I will push a fix for that. 2) there is going to be a difference in batch_idx equal to num_workers * prefetch_factor - 1, because in the script that saved the state dict, the dataloader workers had already "consumed" the batch of cuts from the sampler, but it had not actually been used in training yet. I am not sure how we could work around it, but it's probably not a big issue for large datasets.
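
To make the offset in 2) concrete, with hypothetical settings:

    # Hypothetical numbers, just to illustrate the formula above:
    num_workers = 2
    prefetch_factor = 2  # torch.utils.data.DataLoader default
    offset = num_workers * prefetch_factor - 1  # = 3
    # i.e. the sampler's count can be ~3 batches ahead of the batches the
    # training loop had actually consumed when the checkpoint was saved.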

csukuangfj commented 2 years ago

you’d need to do sth like:

Thanks!

csukuangfj commented 2 years ago

for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)

There is only https://github.com/lhotse-speech/lhotse/blob/3685d8c6fc8f4e3c773ac4e851b2265f38c05115/lhotse/dataset/sampling/base.py#L387 and there is no num_batches_kept.

Also, I find that train_dl.sampler.diagnostics.num_kept_batches is always 0 during training, at least for the first several batches. Is that expected?

csukuangfj commented 2 years ago

Also, if I use start in enumerate, I think it changes only batch_idx; the dataloader still returns elements starting from the 0-th one.

>>> a = list(range(10))
>>> a
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> for i, v in enumerate(a, start=3): print(i, v)
...
3 0
4 1
5 2
6 3
7 4
8 5
9 6
10 7
11 8
12 9
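
For contrast, actually skipping the first elements (rather than only shifting the index) would need something like itertools.islice; this is just to illustrate the difference, not a suggestion for the training loop:

    from itertools import islice

    a = list(range(10))
    # Skip the first 3 elements and keep the indices aligned:
    for i, v in enumerate(islice(a, 3, None), start=3):
        print(i, v)  # prints 3 3, 4 4, ..., 9 9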
danpovey commented 2 years ago

I am also noticing the loss being better after we re-load from a checkpoint, around 0.05 -> 0.04. I suspect something about the SpecAugment settings may have changed. It is not just a transient issue; it stays lower. I'm a bit concerned that this feature might be a bug farm.

pzelasko commented 2 years ago

Yes, I think you're right about it potentially being a bug farm.

I looked into it a bit yesterday; BucketingSampler and ZipSampler need some fixes to support it correctly. Regarding SpecAugment, I don't know what could have changed, other than the RNG state being different from what it was at the point where training stopped.

I will commit a fix later that should resolve the bucketing sampler issue.