k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

on-the-fly fbank feats #666

Open Cescfangs opened 1 year ago

Cescfangs commented 1 year ago

Hey guys, I noticed there are on-the-fly feats in asr_datamodule.py: https://github.com/k2-fsa/icefall/blob/32de2766d591d2e1a77c06a40d2861fb1bbcd3ad/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py#L279-L298

However, I didn't find any recipe using those feats. How could I use on-the-fly feats instead of making fbank first? (I'm using a large dataset, so making fbank locally would not be possible due to disk capacity.) BTW, does icefall support reading feats.scp directly like ESPnet does? (We have precomputed Kaldi fbank.)

csukuangfj commented 1 year ago

Passing

--on-the-fly-feats true

to train.py will do.

You can use

train.py --help

to view the usage.
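For context, a simplified sketch of what the on-the-fly path in the linked asr_datamodule.py does (illustrative, not the exact recipe code):

```python
from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

# When --on-the-fly-feats is true, the datamodule builds the training dataset
# roughly like this: fbank is computed from raw audio inside the dataloader
# workers instead of being read from precomputed feature files.
train_dataset = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)
```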

Cescfangs commented 1 year ago

Thanks for the reply. So I could just skip the fbank-making stages in prepare.sh?

csukuangfj commented 1 year ago

Yes, I believe so.

Cescfangs commented 1 year ago

Thanks mate.

Cescfangs commented 1 year ago

@csukuangfj I found that on-the-fly feature computation makes training much slower. For example, 50 batch iterations take about 20 seconds with precomputed Kaldi fbank feats but about 4 minutes with on-the-fly computation under the same conditions. I noticed you have trained with on-the-fly feats on large datasets (https://github.com/k2-fsa/icefall/pull/312#issuecomment-1096641908); how did you resolve this problem?

csukuangfj commented 1 year ago

Are you using raw waves? Also, is your disk fast?

Cescfangs commented 1 year ago

> Are you using raw waves? Also, is your disk fast?

Yes, I'm using raw waves. How can I check whether my disk is fast or slow?

Cescfangs commented 1 year ago

BTW, I've trained on raw waves with ESPnet and the GPU utilization was around 70%, which I think is normal. The difference is that in ESPnet I implement fbank as a frontend layer (part of the model, running on the GPU), so maybe my disk is not the bottleneck?

pzelasko commented 1 year ago

Can you try increasing the number of dataloader workers? Perhaps that’s the bottleneck.

If you want to use fbank as a layer you can modify the code to use https://github.com/lhotse-speech/lhotse/blob/eb9e6b115729697c66c0a7f5f7ba08984b6a1ee5/lhotse/features/kaldi/layers.py#L476

If it turns out to be a slow disk problem you can speed up the I/O at the cost of an extra copy of the data using: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb
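Assuming the linked line points at the log-Mel filter bank layer (Wav2LogFilterBank), a minimal sketch of using it as a GPU-side module might look like this:

```python
import torch
from lhotse.features.kaldi.layers import Wav2LogFilterBank

# Fbank as an nn.Module, so it can live inside the model and run on the GPU.
fbank = Wav2LogFilterBank(sampling_rate=16000).to("cuda")

waves = torch.randn(4, 16000, device="cuda")  # (batch, num_samples) of raw audio
feats = fbank(waves)                          # (batch, num_frames, num_mel_bins)
```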

Cescfangs commented 1 year ago

> Can you try increasing the number of dataloader workers? Perhaps that’s the bottleneck.
>
> If you want to use fbank as a layer you can modify the code to use https://github.com/lhotse-speech/lhotse/blob/eb9e6b115729697c66c0a7f5f7ba08984b6a1ee5/lhotse/features/kaldi/layers.py#L476
>
> If it turns out to be a slow disk problem you can speed up the I/O at the cost of an extra copy of the data using: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb

@pzelasko thanks for the advice, I have done some experiments:

| feature_type | dataloader workers | OnTheFlyFeatures workers | runtime (50 batches) |
| --- | --- | --- | --- |
| precomputed fbank | 4 | - | 20s |
| KaldifeatFbank | 4 | 0 | 240s |
| KaldifeatFbank | 4 | 4 | 140s |
| + lhotse.set_caching_enabled(True) | 4 | 4 | 130s |
| | 8 | 4 | 80s |
| | 8 | 8 | 80s |
| | 16 | 4 | 50s |

I'm using KaldifeatFbank because it's compatible with Kaldi.

csukuangfj commented 1 year ago

> I'm using KaldifeatFbank because it's compatible with Kaldi.

Can you try to use GPU to extract features?

KaldifeatFbank supports GPU. If you are using DDP, you can use device="cuda:0", device="cuda:1", etc., to specify the device.
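A minimal sketch of that (the `rank` variable is a stand-in for however your DDP setup numbers its processes):

```python
from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig

rank = 0  # stand-in: under DDP this would be the process's rank
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=f"cuda:{rank}"))
```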

pzelasko commented 1 year ago

You are getting the best gains by increasing dataloader workers, so it’s likely an I/O bottleneck; using webdataset or Lhotse Shar may help.

BTW, the fbank I posted is also compatible with Kaldi. Note that regardless of which one you choose, you’ll need to move the fbank computation from the dataloader to the training loop to leverage the GPU.
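One possible shape of that, as a sketch: use the AudioSamples input strategy so the dataloader returns raw audio, and run an fbank module (e.g. the layer linked above) on the GPU inside the loop. `train_dl` below stands in for whatever dataloader you build from the dataset.

```python
import torch
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import AudioSamples
from lhotse.features.kaldi.layers import Wav2LogFilterBank

device = torch.device("cuda", 0)

# The dataloader only collates raw audio; fbank runs on the GPU in the loop.
dataset = K2SpeechRecognitionDataset(input_strategy=AudioSamples())
fbank = Wav2LogFilterBank(sampling_rate=16000).to(device)

for batch in train_dl:                  # train_dl built from `dataset` as usual
    audio = batch["inputs"].to(device)  # (batch, num_samples)
    feats = fbank(audio)                # (batch, num_frames, num_mel_bins)
    # ... model forward/backward on `feats` ...
```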

Cescfangs commented 1 year ago

@pzelasko I got errors when increasing world_size from 1 to 8, can you give me some advice?

"2023-08-09T09:46:00+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:00+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:01+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:02+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:03+08:00" Traceback (most recent call last):
"2023-08-09T09:46:03+08:00"   File "./pruned_transducer_stateless5_bs/train.py", line 1475, in <module>
"2023-08-09T09:46:03+08:00"     main()
"2023-08-09T09:46:03+08:00"   File "./pruned_transducer_stateless5_bs/train.py", line 1464, in main
"2023-08-09T09:46:03+08:00"     mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
"2023-08-09T09:46:03+08:00"     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
"2023-08-09T09:46:03+08:00"     while not context.join():
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
"2023-08-09T09:46:03+08:00"     raise ProcessRaisedException(msg, error_index, failed_process.pid)
"2023-08-09T09:46:03+08:00" torch.multiprocessing.spawn.ProcessRaisedException: 
"2023-08-09T09:46:03+08:00" 
"2023-08-09T09:46:03+08:00" -- Process 0 terminated with the following error:
"2023-08-09T09:46:03+08:00" Traceback (most recent call last):
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
"2023-08-09T09:46:03+08:00"     fn(i, *args)
"2023-08-09T09:46:03+08:00"   File "/data1/icefall-master/egs/hik/asr2_audio/pruned_transducer_stateless5_bs/train.py", line 1340, in run
"2023-08-09T09:46:03+08:00"     train_one_epoch(
"2023-08-09T09:46:03+08:00"   File "/data1/icefall-master/egs/hik/asr2_audio/pruned_transducer_stateless5_bs/train.py", line 1043, in train_one_epoch
"2023-08-09T09:46:03+08:00"     for batch_idx, batch in enumerate(train_dl):
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
"2023-08-09T09:46:03+08:00"     return self._get_iterator()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 384, in _get_iterator
"2023-08-09T09:46:03+08:00"     return _MultiProcessingDataLoaderIter(self)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1086, in __init__
"2023-08-09T09:46:03+08:00"     self._reset(loader, first_iter=True)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1119, in _reset
"2023-08-09T09:46:03+08:00"     self._try_put_index()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1353, in _try_put_index
"2023-08-09T09:46:03+08:00"     index = self._next_index()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 642, in _next_index
"2023-08-09T09:46:03+08:00"     return next(self._sampler_iter)  # may raise StopIteration
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/base.py", line 261, in __next__
"2023-08-09T09:46:03+08:00"     batch = self._next_batch()
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 237, in _next_batch
"2023-08-09T09:46:03+08:00"     batch = next(self.cuts_iter)
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 360, in __iter__
"2023-08-09T09:46:03+08:00"     ready_buckets = [b for b in self.buckets if is_ready(b)]
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 360, in <listcomp>
"2023-08-09T09:46:03+08:00"     ready_buckets = [b for b in self.buckets if is_ready(b)]
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 351, in is_ready
"2023-08-09T09:46:03+08:00"     tot.add(c[0] if isinstance(c, tuple) else c)
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/base.py", line 350, in add
"2023-08-09T09:46:03+08:00"     self.current += cut.duration
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
"2023-08-09T09:46:03+08:00"     _error_if_any_worker_fails()
"2023-08-09T09:46:03+08:00" RuntimeError: DataLoader worker (pid 76863) is killed by signal: Aborted. 
"2023-08-09T09:46:03+08:00" 
"2023-08-09T09:46:03+08:00" [INFO] recv error: exit status 1
"2023-08-09T09:46:03+08:00" [ERROR] error happends during process: exit status 1
"2023-08-09T09:46:03+08:00" [INFO] still reserved
"2023-08-09T09:46:03+08:00" [INFO] recv flag (false)
"2023-08-09T09:46:03+08:00" [INFO] sleeping
pzelasko commented 1 year ago

Can you reduce the number of workers (especially for on the fly features) and see if it helps?

Cescfangs commented 1 year ago

> I'm using KaldifeatFbank because it's compatible with Kaldi.
>
> Can you try to use GPU to extract features?
>
> KaldifeatFbank supports GPU. If you are using DDP, you can use device="cuda:0", device="cuda:1", etc., to specify the device.

@csukuangfj I tried using GPU for feature extraction, but it seems that we can't re-initialize CUDA in a forked subprocess:

Original Traceback (most recent call last):
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = self.dataset[possibly_batched_index]
  File "/data1/tools/lhotse-master/lhotse/dataset/speech_recognition.py", line 113, in __getitem__
    input_tpl = self.input_strategy(cuts)
  File "/data1/tools/lhotse-master/lhotse/dataset/input_strategies.py", line 380, in __call__
    features_single = self.extractor.extract_batch(
  File "/data1/tools/lhotse-master/lhotse/features/kaldifeat.py", line 84, in extract_batch
    return self.extract(samples=samples, sampling_rate=sampling_rate)
  File "/data1/tools/lhotse-master/lhotse/features/kaldifeat.py", line 125, in extract
    result = self.extractor(samples, chunk_size=self.config.chunk_size)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/kaldifeat/offline_feature.py", line 79, in forward
    features = self.compute(strided, vtln_warp, chunk_size=chunk_size)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/kaldifeat/offline_feature.py", line 135, in compute
    x[end:].to(self_device), vtln_warp
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/cuda/__init__.py", line 207, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
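For reference, the error message itself points at the usual workarounds: keep CUDA out of forked dataloader workers, or start the workers with 'spawn'. A minimal, non-icefall-specific sketch (`dataset` and `sampler` stand in for whatever the recipe builds):

```python
import torch

# Option 1: no worker subprocesses, so GPU work stays in the main process.
train_dl = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, num_workers=0)

# Option 2: start the workers with 'spawn' instead of 'fork', so CUDA can be
# initialized inside them (at the cost of slower worker startup).
train_dl = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=sampler,
    num_workers=4,
    multiprocessing_context="spawn",
)
```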
Cescfangs commented 1 year ago

> Can you reduce the number of workers (especially for on the fly features) and see if it helps?

Yes. With world_size=8 I can run at most 8 dataloader workers and 1 OnTheFlyFeatures worker; the average runtime for 50 batches is around 95s, which seems very reasonable according to the table (https://github.com/k2-fsa/icefall/issues/666#issuecomment-1669542121).

| feature_type | dataloader workers | OnTheFlyFeatures workers | runtime (50 batches) |
| --- | --- | --- | --- |
| + lhotse.set_caching_enabled(True) | 4 | 4 | 130s |
| | 8 | 4 | 80s |

> You are getting the best gains by increasing dataloader workers, so it’s likely an I/O bottleneck; using webdataset or Lhotse Shar may help.

Is there any icefall recipe using webdataset or Lhotse Shar to follow?

pzelasko commented 1 year ago

AFAIK there's no recipe at this time, but it shouldn't be too involved:

- be aware that it will create a full copy of your audio data
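A rough sketch of the export/read round trip shown in the linked tutorial (argument names may differ between Lhotse versions, so treat the notebook as the reference):

```python
from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # hypothetical manifest path

# Export cuts plus audio into sharded tar files (this is the extra copy of the data).
cuts.to_shar("data/shar", fields={"recording": "wav"}, shard_size=1000)

# Read them back lazily for training; shards can be shuffled each epoch.
train_cuts = CutSet.from_shar(in_dir="data/shar", shuffle_shards=True)
```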