Open tz301 opened 3 months ago
Could you tell us which script you are using?
Have you changed any code or just used the original code from us?
Also, please tell us whether you are using a streaming or a non-streaming model and what is the typical wave duration of your test file.
It would be great if you can post the complete decoding command.
Could you tell us which script you are using?
Have you changed any code or just used the original code from us?
Also, please tell us whether you are using a streaming or a non-streaming model and what is the typical wave duration of your test file.
It would be great if you can post the complete decoding command.
Hi @csukuangfj,
I have export the model using torch.jit.export and write my one decode code, which is used in offline scenario. However, the core decoding code is from icefall.
Actually my wave is usually long than 1 minite, but the max duration of wave file for asr is 20s (which is force cut by energy vad).
I have attached code below.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import math
from argparse import ArgumentParser
from pathlib import Path
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import greedy_search_batch
from icefall.lexicon import Lexicon
from pydub import AudioSegment
from torch.nn.utils.rnn import pad_sequence
_LOGGER = logging.getLogger(__name__)
class Decoder:
def __init__(self, lang_dir):
self._sp = spm.SentencePieceProcessor()
self._sp.load(str(lang_dir / 'bpe.model'))
self._lexicon = Lexicon(lang_dir)
self.blank_id = self._lexicon.token_table['<blk>']
self.vocab_size = max(self._lexicon.tokens) + 1
_LOGGER.info('Decoder Init Succeed.')
@property
def lexicon(self):
return self._lexicon
def decode_tokens(self, tokens):
token_table = self._lexicon.token_table
return self._sp.decode([token_table[idx] for idx in tokens])
def decode(self, model, encoder_out, encoder_out_lens):
pred = list()
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
pred.append(self.decode_tokens(hyp_tokens[i]))
return pred
class Asr:
def __init__(self, model_dir):
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
model_dir = Path(model_dir)
self._device = torch.device('cpu')
self._feat_extractor = self.__get_feat_extractor()
self._model = self._init_model(model_dir)
self._decoder = Decoder(model_dir / 'lang',)
_LOGGER.info('Asr Model init succeed.')
def __get_feat_extractor(self):
opts = kaldifeat.FbankOptions()
opts.device = self._device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 8000
opts.mel_opts.num_bins = 80
feat_extractor = kaldifeat.Fbank(opts)
_LOGGER.info('Fbank feat extractor init succeed.')
return feat_extractor
def _init_model(self, model_dir):
model = torch.jit.load(model_dir / 'model.pt')
model.eval()
model.to(self._device)
_LOGGER.info('ZipFormer Load succeed.')
return model
@staticmethod
def _read_wav_files(wav_files):
waves = list()
for wav_file in wav_files:
wave, sample_rate = torchaudio.load(wav_file)
waves.append(wave[0].contiguous())
return waves
def _get_feature(self, wav_files):
waves = self._read_wav_files(wav_files)
waves = [w.to(self._device) for w in waves]
features = self._feat_extractor(waves)
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(
[f.size(0) for f in features], device=self._device
)
return features, feature_lengths
def _encode(self, features, feature_lengths):
"""编码.
Args:
features: 特征.
feature_lengths: 特征长度.
Returns:
编码输出和编码输出长度.
"""
encoder_out, encoder_out_lens = self._model.encoder(
features=features,
feature_lengths=feature_lengths
)
return encoder_out, encoder_out_lens
def recognize(self, wav_files):
features, feature_lengths = self._get_feature(wav_files)
encoder_out, encoder_out_lens = self._encode(features, feature_lengths)
texts = self._decoder.decode(self._model, encoder_out, encoder_out_lens)
return texts
def _main():
parser = ArgumentParser('recognize')
parser.add_argument('model_dir', type=Path, help='model directory')
parser.add_argument('wav_dir', type=Path, help='wav directory')
args = parser.parse_args()
asr = Asr(args.model_dir)
for wav_file in args.wav_dir.iterdir():
duration = AudioSegment.from_file(wav_file).duration_seconds
text = asr.recognize([wav_file])[0]
_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')
if __name__ == '__main__':
LOGGER_FORMAT = ('%(asctime)s.%(msecs)03d - %(name)s:%(lineno)s '
'- %(funcName)s() - %(levelname)s - %(message)s')
logging.basicConfig(format=LOGGER_FORMAT, level=logging.INFO)
_main()
I see.
Please use
@torch.no_grad() def _main():
as what we are doing in decoding.
Hi @csukuangfj,
Yeah, add @torch.no_grad() seems work, the memory decrease from 2.5G to 2.1G.
I'm not sure is it normal to use ~2G memory, or any other idea to decrease it?
Could you post your updated code?
@csukuangfj
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import math
from argparse import ArgumentParser
from pathlib import Path
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import greedy_search_batch
from icefall.lexicon import Lexicon
from pydub import AudioSegment
from torch.nn.utils.rnn import pad_sequence
_LOGGER = logging.getLogger(__name__)
class Decoder:
def __init__(self, lang_dir):
self._sp = spm.SentencePieceProcessor()
self._sp.load(str(lang_dir / 'bpe.model'))
self._lexicon = Lexicon(lang_dir)
self.blank_id = self._lexicon.token_table['<blk>']
self.vocab_size = max(self._lexicon.tokens) + 1
_LOGGER.info('Decoder Init Succeed.')
@property
def lexicon(self):
return self._lexicon
def decode_tokens(self, tokens):
token_table = self._lexicon.token_table
return self._sp.decode([token_table[idx] for idx in tokens])
def decode(self, model, encoder_out, encoder_out_lens):
pred = list()
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
pred.append(self.decode_tokens(hyp_tokens[i]))
return pred
class Asr:
def __init__(self, model_dir):
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
model_dir = Path(model_dir)
self._device = torch.device('cpu')
self._feat_extractor = self.__get_feat_extractor()
self._model = self._init_model(model_dir)
self._decoder = Decoder(model_dir / 'lang',)
_LOGGER.info('Asr Model init succeed.')
def __get_feat_extractor(self):
opts = kaldifeat.FbankOptions()
opts.device = self._device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 8000
opts.mel_opts.num_bins = 80
feat_extractor = kaldifeat.Fbank(opts)
_LOGGER.info('Fbank feat extractor init succeed.')
return feat_extractor
def _init_model(self, model_dir):
model = torch.jit.load(model_dir / 'model.pt')
model.eval()
model.to(self._device)
_LOGGER.info('ZipFormer Load succeed.')
return model
@staticmethod
def _read_wav_files(wav_files):
waves = list()
for wav_file in wav_files:
wave, sample_rate = torchaudio.load(wav_file)
waves.append(wave[0].contiguous())
return waves
def _get_feature(self, wav_files):
waves = self._read_wav_files(wav_files)
waves = [w.to(self._device) for w in waves]
features = self._feat_extractor(waves)
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(
[f.size(0) for f in features], device=self._device
)
return features, feature_lengths
def _encode(self, features, feature_lengths):
"""编码.
Args:
features: 特征.
feature_lengths: 特征长度.
Returns:
编码输出和编码输出长度.
"""
encoder_out, encoder_out_lens = self._model.encoder(
features=features,
feature_lengths=feature_lengths
)
return encoder_out, encoder_out_lens
def recognize(self, wav_files):
features, feature_lengths = self._get_feature(wav_files)
encoder_out, encoder_out_lens = self._encode(features, feature_lengths)
texts = self._decoder.decode(self._model, encoder_out, encoder_out_lens)
return texts
@torch.no_grad()
def _main():
parser = ArgumentParser('recognize')
parser.add_argument('model_dir', type=Path, help='model directory')
parser.add_argument('wav_dir', type=Path, help='wav directory')
args = parser.parse_args()
asr = Asr(args.model_dir)
for wav_file in args.wav_dir.iterdir():
duration = AudioSegment.from_file(wav_file).duration_seconds
text = asr.recognize([wav_file])[0]
_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')
if __name__ == '__main__':
LOGGER_FORMAT = ('%(asctime)s.%(msecs)03d - %(name)s:%(lineno)s '
'- %(funcName)s() - %(levelname)s - %(message)s')
logging.basicConfig(format=LOGGER_FORMAT, level=logging.INFO)
_main()
Does the memory grow linearly from 0 to 2.1GB and then keep at 2.1 GB?
_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')
Could you give the output of the above log? What is the max value of duration?
Does the memory grow linearly from 0 to 2.1GB and then keep at 2.1 GB?
_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')
Could you give the output of the above log? What is the max value of duration?
The max duration is 20s in my wav files.
The memory first grow to around 1.5G for the first few wavs, then grow slowly to 2.1G and keep at 2.1G.
The first few wavs (around 5 wavs) is extremely slow, may cost 1~2 minute to finish decode. I'm not sure if it's normal that the warm up for this asr model need this time.
With zipformer I can get good performance.
Currently, when I decode on CPU one by one (not using batch), the memory cost will go to 2.5G. The token size is 5000 and using greedy_search to decode. I try to reduce it to 4000, but the memory cost seem not decrease much.
Any idea to reduce it without obvious performance degrade?
Some model configurations below: num-encoder-layers=2,2,2,3,2,2 downsampling-factor=1,2,4,8,4,2 feedforward-dim=256,384,512,768,512,384 num-heads=4,4,4,4,4,4 encoder-dim=192,256,256,384,256,256 query-head-dim=24 value-head-dim=8 pos-head-dim=4 pos-dim=24 encoder-unmasked-dim=192,192,256,256,256,192 cnn-module-kernel=31,31,15,15,15,31 decoder-dim=256 joiner-dim=256