k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

Possible issue with randomness in augmentations #270

Closed. pzelasko closed this issue 7 months ago.

pzelasko commented 2 years ago

I've just noticed that adding this function:

def worker_init_fn(worker_id: int):
    lhotse.utils.fix_random_seed(42 + worker_id)

to the dataloader: DataLoader(..., worker_init_fn=worker_init_fn)

results in (significantly) different training losses; for me they were much higher. I think that in the current training there might be some repetitive patterns in the augmentations which make it easier for the model to train against. It would be great if somebody here could try it out on some recent model and confirm whether I'm right or not.
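
For anyone who wants to try this, a self-contained sketch of the wiring (the base seed 42 matches the snippet above; train_dataset, train_sampler and num_workers=4 are placeholders, not this recipe's actual objects):

from torch.utils.data import DataLoader
from lhotse.utils import fix_random_seed

def worker_init_fn(worker_id: int):
    # Give each dataloader worker process its own RNG state so that
    # augmentations are not repeated identically across workers.
    fix_random_seed(42 + worker_id)

# 'train_dataset' and 'train_sampler' stand in for whatever the recipe
# actually constructs (e.g. a K2SpeechRecognitionDataset and its sampler).
train_dl = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=None,
    num_workers=4,
    worker_init_fn=worker_init_fn,
)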

pzelasko commented 2 years ago

… and for multi-GPU training it might also be necessary to add the rank value * 100 or something similar.
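
One possible way to fold in the rank under DDP (a sketch; the 100 multiplier and base seed are arbitrary):

import torch.distributed as dist
from lhotse.utils import fix_random_seed

# Captured once in the main (per-GPU) process, before the workers are created,
# so the closure carries the rank into each dataloader worker.
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

def worker_init_fn(worker_id: int):
    fix_random_seed(42 + 100 * rank + worker_id)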

danpovey commented 2 years ago

Hm. I think this is problematic because it does not take into account the epoch, and would not be at all repeatable. (?) is worker_id a process-id, or the rank?

danpovey commented 2 years ago

... Also, at the start of every epoch in the main process we do fix_random_seed(), and we know this makes a difference because of our recent experience with the checkpoint-code changes. So I'm a bit confused now. Perhaps the fix_random_seed() affected the SpecAug but not the data selection, or something like that?

danpovey commented 2 years ago

I'm trying setting the seed like this:

        # 'seed' is derived from the current random state, which will have previously been
        # set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        def worker_init_fn(worker_id: int):
            lhotse.utils.fix_random_seed(seed + worker_id)

danpovey commented 2 years ago

There appears to be a problem with this, as in lhotse's __init__.py it does from augmentation import *, which imports utils, but that makes lhotse/utils.py unavailable.

danpovey commented 2 years ago

I think the current version of lhotse is not functional for this reason. I'm trying to fix it but there's quite a bit to untangle, due to extensive use of from something import *. I managed to fix it enough to get my experiment to run, this way:

diff --git a/lhotse/augmentation/utils.py b/lhotse/augmentation/utils.py
index c4f342a..afef7c1 100644
--- a/lhotse/augmentation/utils.py
+++ b/lhotse/augmentation/utils.py
@@ -69,3 +69,5 @@ def convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
     result = irfft(f_result, n=fast_ftt_size)

     return result[:padded_size]
+
+from lhotse.utils import * # Terrible hack.

... BTW I installed with pip install -e '.[dev]', not sure if this matters. Also, the way I was using this was import lhotse and then, in code, lhotse.utils.fix_random_seed(). However, even using the current version of lhotse, I am able to successfully do from lhotse.utils import fix_random_seed. It seems lhotse's __init__.py is problematic right now.
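
In other words, on that lhotse version the attribute-style access was what broke, while the direct import still worked:

# Broken at the time: the wildcard imports in lhotse/__init__.py left
# lhotse.utils unavailable, per the discussion above.
#   import lhotse
#   lhotse.utils.fix_random_seed(42)

# Worked:
from lhotse.utils import fix_random_seed
fix_random_seed(42)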

pzelasko commented 2 years ago

> Hm. I think this is problematic because it does not take into account the epoch, and would not be at all repeatable. (?)

We could use lhotse.utils.fix_random_seed(seed + epoch + 100 * worker_id + 10000 * gpu_id) for dataloading workers to make it epoch dependent. We also need to account for GPU id so that each GPU sees data transformed differently.

> is worker_id a process-id, or the rank?

worker_id is an int in the range [0, num_workers - 1]; each dataloader worker is a separate process that can set its own global RNG states.
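
A sketch of that scheme; epoch and gpu_id (the DDP rank) would be captured in the main process each time the train dataloader is built, and the multipliers just follow the formula above:

from lhotse.utils import fix_random_seed

def make_worker_init_fn(seed: int, epoch: int, gpu_id: int):
    # 'epoch' and 'gpu_id' come from the main training process; each worker
    # then derives a unique, epoch-dependent seed from them.
    def worker_init_fn(worker_id: int):
        fix_random_seed(seed + epoch + 100 * worker_id + 10000 * gpu_id)
    return worker_init_fn

The returned closure would then be passed as worker_init_fn=... when constructing the train DataLoader.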

> ... Also, at the start of every epoch in the main process we do fix_random_seed(), and we know this makes a difference because of our recent experience with the checkpoint-code changes. So I'm a bit confused now. Perhaps the fix_random_seed() affected the SpecAug but not the data selection, or something like that?

My understanding is that:

> There appears to be a problem with this, as in lhotse's __init__.py it does from augmentation import *, which imports utils, but that makes lhotse/utils.py unavailable.

I'll look into it and find a fix.

danpovey commented 2 years ago

Rather than finding the gpu_id, I think my approach of just calling torch.randint(()) to get the seed may be easier, as it is naturally affected by higher-level setting of seeds. BTW, I did run with the new seed-setting code that I showed above (modified to use from lhotse.utils import fix_random_seed). WER and train loss were perhaps a little better, but I'm not 100% convinced it wasn't just random. This was with 1 GPU, and the only augmentation was SpecAug plus Musan.
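
A toy check of that property (not from the thread): because the dataloader seed is drawn from the main process's RNG state, it automatically tracks whatever per-epoch seeding the main process does:

import torch

torch.manual_seed(42 + 3)                      # e.g. some base seed plus the epoch number
seed_a = torch.randint(0, 100000, ()).item()

torch.manual_seed(42 + 3)                      # same epoch -> same derived dataloader seed
seed_b = torch.randint(0, 100000, ()).item()

torch.manual_seed(42 + 4)                      # next epoch -> an (almost surely) different seed
seed_c = torch.randint(0, 100000, ()).item()

assert seed_a == seed_b
print(seed_a, seed_c)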

danpovey commented 2 years ago

After running this for more epochs on a train-clean-100 setup, WER changes are: baseline=6.82/17.69, fixed=6.76/17.38, on test-clean/test-other, tested at --epoch 39 --avg 10. This is starting to get large enough that I would consider it non-random. @csukuangfj we might want to apply this in icefall more generally. I am of course incorporating it in my rework of the recipe, but this may not be merged for another week, so perhaps we could merge this fix earlier? OTOH, a case for delaying is that applying the fix may confuse experimental comparisons a bit. Up to you.

 git diff rework2h{,_randloader}
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index a460c8e..3efe7ec 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -22,6 +22,8 @@ import logging
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
+import torch
+from lhotse.utils import fix_random_seed

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
@@ -301,12 +303,19 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)

+        # 'seed' is derived from the current random state, which will have previously been
+        # set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        def worker_init_fn(worker_id: int):
+            fix_random_seed(seed + worker_id)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )

        return train_dl

pzelasko commented 2 years ago

Thanks for looking into it.