csukuangfj opened this issue 2 years ago
Also, I find that the sampler's state_dict contains only the current epoch at which the checkpoint was saved. It does not record at which batch within that epoch the checkpoint was saved.
As a result, I have to use https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L618

cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
    if batch_idx < cur_batch_idx:
        continue
    cur_batch_idx = batch_idx

to skip a specified number of batches when resuming training from a checkpoint, which may take several minutes.
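For completeness, here is a hedged sketch of the saving side of that workaround; the names save_every_n, model, optimizer and the checkpoint path are placeholders, not the actual icefall code:

import torch

# Hypothetical sketch: persist cur_batch_idx alongside the rest of the
# checkpoint so the loop above can fast-forward to it after resuming.
if batch_idx % save_every_n == 0:
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "sampler": train_dl.sampler.state_dict(),
        "cur_batch_idx": batch_idx,
    }
    torch.save(checkpoint, f"checkpoint-{batch_idx}.pt")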
You're right that it doesn't store batch_idx; instead it stores the number of cuts that were already processed (it's inside diagnostics). That's sufficient for the sampler to correctly restore its state, and you don't need your workaround that requires long waiting times.
As to the warning, I'll take another look at it later; it probably shouldn't compare the fields current and num_cuts.
Actually I was wrong -- both the number of cuts and the number of batches that were consumed are kept in diagnostics, so you can read out everything. However, the diagnostics are reset after every epoch -- but I think it makes sense to change it so that they keep accumulating. WDYT?
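For example, after restoring the state one should be able to read the consumed counts back from the sampler. The attribute name num_kept_batches is taken from later in this thread, so treat this as a sketch against an assumed API rather than a guaranteed one:

# Sketch: inspect the sampler diagnostics after loading a checkpointed state.
train_dl.sampler.load_state_dict(checkpoint["sampler"])
diag = train_dl.sampler.diagnostics
print("batches consumed so far:", diag.num_kept_batches)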
Also check this PR, which addresses some of your other comments: https://github.com/lhotse-speech/lhotse/pull/632
It's sufficient for the sampler to correctly restore its state, and you don't need your workaround that requires long waiting times.
The issue is that the batch_idx in the following line starts from 0 even if we resume training from a checkpoint. That may confuse users, as it appears that training does not pick up from the location where it was previously saved.
for batch_idx, batch in enumerate(train_dl):
you’d need to do something like:
for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)
You will encounter two issues with that:
1) the diagnostics are reset after each epoch, but I will push a fix for that.
2) there is going to be a difference in batch_idx equal to num_workers * prefetch_factor - 1, because in the script that saved the state dict, the dataloader workers had already "consumed" batches from the sampler that were not actually used in training yet (see the sketch after this list). I am not sure how we could work around it, but it's probably not a big issue for large datasets.
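If you did want to correct for that offset, a rough sketch might look like the following; it is purely illustrative, and whether subtracting is even the right adjustment depends on how far the workers had drained the sampler:

# Rough sketch of accounting for batches already pulled by dataloader workers
# but not yet used in training when the state dict was saved.
num_workers = train_dl.num_workers
prefetch_factor = train_dl.prefetch_factor if num_workers > 0 else 0
prefetch_offset = max(num_workers * prefetch_factor - 1, 0)
resume_batch_idx = train_dl.sampler.diagnostics.num_kept_batches - prefetch_offset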
you’d need to do something like:
for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)

Thanks! There is only the attribute at
https://github.com/lhotse-speech/lhotse/blob/3685d8c6fc8f4e3c773ac4e851b2265f38c05115/lhotse/dataset/sampling/base.py#L387
and there is no num_batches_kept.
Also, I find that train_dl.sampler.diagnostics.num_kept_batches is always 0 during training, at least for the first several batches. Is that expected?
Also, if I use start in enumerate, I think it changes only batch_idx, but it still returns the 0-th element from the dataloader:
>>> a = list(range(10))
>>> a
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> for i, v in enumerate(a, start=3): print(i, v)
...
3 0
4 1
5 2
6 3
7 4
8 5
9 6
10 7
11 8
12 9
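In other words, start= only shifts the reported index; to actually skip the first N items you would have to advance the iterator itself, e.g. with itertools.islice (a generic Python sketch, not lhotse-specific):

import itertools

a = list(range(10))
# Slice the iterable so the first 3 elements are really skipped:
for i, v in enumerate(itertools.islice(a, 3, None), start=3):
    print(i, v)   # prints (3, 3), (4, 4), ..., (9, 9)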
I am also noticing that the loss is better after we re-load from a checkpoint, around 0.05 -> 0.04. I suspect something about the SpecAugment settings may have changed. It is not just a transient issue; it stays lower. I'm a bit concerned that this feature might be a bug farm.
Yes, I think you’re right about it potentially being a bug farm.
I looked into it a bit yesterday; BucketingSampler and ZipSampler need some fixes to support it correctly. Regarding SpecAugment, I don’t know what could have changed, other than the RNG state being different from what it was at the point the training stopped.
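If the RNG state is the cause, one way to rule it out would be to checkpoint and restore the RNG state together with the model; a generic PyTorch sketch, not something this thread says icefall or lhotse already do:

import random
import numpy as np
import torch

# Save the RNG states next to the model/sampler state...
rng_state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "torch": torch.get_rng_state(),
}

# ...and restore them after loading the checkpoint, so SpecAugment (and any
# other randomized transform) continues from the same point in its RNG stream.
random.setstate(rng_state["python"])
np.random.set_state(rng_state["numpy"])
torch.set_rng_state(rng_state["torch"])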
I will commit a fix later that should resolve the bucketing sampler issue.
I get the following warning while trying to use https://github.com/k2-fsa/icefall/pull/259 to restore the state dict of a sampler from a checkpoint.
Related code is listed below: https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L300
https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L427
https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L800
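For reference, the save/restore pattern being exercised by that PR is roughly the following; this is a minimal sketch with placeholder names (model, checkpoint_path), not the actual icefall code:

import torch

# Saving: put the sampler state next to the model state in the checkpoint.
checkpoint = {
    "model": model.state_dict(),
    "sampler": train_dl.sampler.state_dict(),
}
torch.save(checkpoint, checkpoint_path)

# Resuming: restore the sampler state before iterating over train_dl again;
# this load_state_dict call is presumably where the warning is emitted.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
train_dl.sampler.load_state_dict(checkpoint["sampler"])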