k2-fsa / icefall

Apache License 2.0

How to reduce memory when decoding on CPU? #1672

Open tz301 opened 3 months ago

tz301 commented 3 months ago

With Zipformer I get good recognition performance.

Currently, when I decode on CPU one file at a time (no batching), memory usage goes up to 2.5 GB. The vocabulary has 5000 tokens and I decode with greedy_search. I tried reducing it to 4000 tokens, but memory usage did not seem to decrease much.

Any ideas for reducing memory without an obvious degradation in performance?

The model configuration is as follows:

num-encoder-layers=2,2,2,3,2,2
downsampling-factor=1,2,4,8,4,2
feedforward-dim=256,384,512,768,512,384
num-heads=4,4,4,4,4,4
encoder-dim=192,256,256,384,256,256
query-head-dim=24
value-head-dim=8
pos-head-dim=4
pos-dim=24
encoder-unmasked-dim=192,192,256,256,256,192
cnn-module-kernel=31,31,15,15,15,31
decoder-dim=256
joiner-dim=256
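A quick way to see why shrinking the vocabulary barely helps: in an icefall-style transducer, only the decoder embedding and the joiner's output layer scale with the vocabulary size. The sketch below counts those parameters using the decoder-dim and joiner-dim posted above. The layer shapes (`nn.Embedding(vocab_size, decoder_dim)` and `nn.Linear(joiner_dim, vocab_size)`) are an assumption about the recipe, and the function name is hypothetical.

```python
def vocab_dependent_params(vocab_size: int, decoder_dim: int, joiner_dim: int) -> int:
    """Count the parameters whose number scales with the vocabulary size.

    Assumes an icefall-style stateless decoder and joiner:
      decoder embedding:   nn.Embedding(vocab_size, decoder_dim)
      joiner output layer: nn.Linear(joiner_dim, vocab_size)  (weight + bias)
    The encoder does not depend on vocab_size at all.
    """
    embedding = vocab_size * decoder_dim
    output_linear = joiner_dim * vocab_size + vocab_size
    return embedding + output_linear


BYTES_PER_FLOAT32 = 4
for vocab in (5000, 4000):
    n = vocab_dependent_params(vocab, decoder_dim=256, joiner_dim=256)
    print(f"vocab={vocab}: {n:,} params, {n * BYTES_PER_FLOAT32 / 2**20:.1f} MiB")

saved = vocab_dependent_params(5000, 256, 256) - vocab_dependent_params(4000, 256, 256)
print(f"saved by 5000 -> 4000: {saved * BYTES_PER_FLOAT32 / 2**20:.1f} MiB")
```

Under these assumptions, going from 5000 to 4000 tokens removes about 513,000 float32 parameters, roughly 2 MiB, which is negligible next to a 2.5 GB process.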

csukuangfj commented 3 months ago

Could you tell us which script you are using?

Have you changed any code or just used the original code from us?

Also, please tell us whether you are using a streaming or a non-streaming model and what is the typical wave duration of your test file.

It would be great if you can post the complete decoding command.

tz301 commented 3 months ago

Could you tell us which script you are using?

Have you changed any code or just used the original code from us?

Also, please tell us whether you are using a streaming or a non-streaming model and what is the typical wave duration of your test file.

It would be great if you can post the complete decoding command.

Hi @csukuangfj,

I exported the model using torch.jit.export and wrote my own decoding code for an offline scenario. The core decoding code, however, comes from icefall.

My recordings are usually longer than 1 minute, but each segment passed to ASR is at most 20 s long (longer audio is force-cut by an energy-based VAD).
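Because self-attention builds a (heads × T × T) score tensor, segment length likely matters far more than vocabulary size. The following is a rough back-of-the-envelope sketch. It assumes 100 fbank frames/s (10 ms shift), a front end that subsamples to about 50 frames/s, each stack running at that rate divided by its downsampling-factor, and one float32 score tensor per layer. These are simplifications, not icefall's exact shapes, and `attention_score_bytes` is a hypothetical helper.

```python
def attention_score_bytes(duration_s: float, num_heads: int, downsampling: int,
                          base_frame_rate: float = 50.0, bytes_per_float: int = 4) -> int:
    """Approximate size of one (num_heads, T, T) float32 attention-score tensor.

    Assumptions (not taken from icefall's code): the stack runs at
    base_frame_rate / downsampling frames per second, and frames lost at the
    edges of the convolutional front end are ignored.
    """
    num_frames = int(duration_s * base_frame_rate) // downsampling
    return num_heads * num_frames * num_frames * bytes_per_float


# Values from the configuration posted above.
downsampling_factors = [1, 2, 4, 8, 4, 2]
num_heads = [4, 4, 4, 4, 4, 4]

for duration in (10.0, 20.0):
    sizes = [attention_score_bytes(duration, h, ds)
             for h, ds in zip(num_heads, downsampling_factors)]
    print(f"{duration:.0f} s: " + ", ".join(f"{s / 1e6:.2f} MB" for s in sizes))
```

With these assumptions, a 20 s segment gives a 16 MB score tensor per layer in the first stack, versus 4 MB for a 10 s segment. Doubling the segment length quadruples this term. With gradient tracking enabled, autograd also keeps the intermediates it would need for a backward pass alive across all 13 encoder layers until the outputs are released.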

The code is attached below.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import math
from argparse import ArgumentParser
from pathlib import Path

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import greedy_search_batch
from icefall.lexicon import Lexicon
from pydub import AudioSegment
from torch.nn.utils.rnn import pad_sequence

_LOGGER = logging.getLogger(__name__)

class Decoder:

    def __init__(self, lang_dir):
        self._sp = spm.SentencePieceProcessor()
        self._sp.load(str(lang_dir / 'bpe.model'))

        self._lexicon = Lexicon(lang_dir)

        self.blank_id = self._lexicon.token_table['<blk>']
        self.vocab_size = max(self._lexicon.tokens) + 1
        _LOGGER.info('Decoder Init Succeed.')

    def lexicon(self):
        return self._lexicon

    def decode_tokens(self, tokens):
        token_table = self._lexicon.token_table
        return self._sp.decode([token_table[idx] for idx in tokens])

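    def decode(self, model, encoder_out, encoder_out_lens):
        pred = list()
        # greedy_search_batch returns one list of token IDs per utterance.
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for i in range(encoder_out.size(0)):
            pred.append(self.decode_tokens(hyp_tokens[i]))
        return pred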

class Asr:

    def __init__(self, model_dir):
        model_dir = Path(model_dir)
        self._device = torch.device('cpu')
        self._feat_extractor = self.__get_feat_extractor()
        self._model = self._init_model(model_dir)
        self._decoder = Decoder(model_dir / 'lang',)
        _LOGGER.info('Asr Model init succeed.')

    def __get_feat_extractor(self):
        opts = kaldifeat.FbankOptions()
        opts.device = self._device
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = 8000
        opts.mel_opts.num_bins = 80
        feat_extractor = kaldifeat.Fbank(opts)
        _LOGGER.info('Fbank feat extractor init succeed.')
        return feat_extractor

    def _init_model(self, model_dir):
        model = torch.jit.load(model_dir / 'model.pt')
        _LOGGER.info('ZipFormer Load succeed.')
        return model

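    def _read_wav_files(self, wav_files):
        waves = list()
        for wav_file in wav_files:
            wave, sample_rate = torchaudio.load(wav_file)
            # The fbank extractor above is configured for 8 kHz audio.
            assert sample_rate == 8000, f'{wav_file}: expected 8 kHz, got {sample_rate}'
            # Keep only the first channel as a 1-D float tensor.
            waves.append(wave[0].contiguous())
        return waves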

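    def _get_feature(self, wav_files):
        waves = self._read_wav_files(wav_files)
        waves = [w.to(self._device) for w in waves]

        features = self._feat_extractor(waves)
        # Record the true lengths before padding; after pad_sequence every
        # entry would report the padded (maximum) length.
        feature_lengths = torch.tensor(
            [f.size(0) for f in features], device=self._device
        )
        # Pad with log(1e-10), i.e. the log-energy of near silence.
        features = pad_sequence(
            features, batch_first=True, padding_value=math.log(1e-10)
        )
        return features, feature_lengths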

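    def _encode(self, features, feature_lengths):
        """Run the encoder.

        Args:
            features: padded fbank features, shape (N, T, 80).
            feature_lengths: number of valid frames per utterance, shape (N,).
        """
        encoder_out, encoder_out_lens = self._model.encoder(
            features, feature_lengths
        )
        return encoder_out, encoder_out_lens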

    def recognize(self, wav_files):
        features, feature_lengths = self._get_feature(wav_files)
        encoder_out, encoder_out_lens = self._encode(features, feature_lengths)
        texts = self._decoder.decode(self._model, encoder_out, encoder_out_lens)
        return texts

def _main():
    parser = ArgumentParser('recognize')
    parser.add_argument('model_dir', type=Path, help='model directory')
    parser.add_argument('wav_dir', type=Path, help='wav directory')
    args = parser.parse_args()

    asr = Asr(args.model_dir)
    for wav_file in args.wav_dir.iterdir():
        duration = AudioSegment.from_file(wav_file).duration_seconds
        text = asr.recognize([wav_file])[0]
        _LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')

if __name__ == '__main__':
    LOGGER_FORMAT = ('%(asctime)s.%(msecs)03d - %(name)s:%(lineno)s '
                     '- %(funcName)s() - %(levelname)s - %(message)s')
    logging.basicConfig(format=LOGGER_FORMAT, level=logging.INFO)
    _main()

csukuangfj commented 3 months ago

I see.

Please use

@torch.no_grad()
def _main():

as we do in our decoding scripts.
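The effect of the decorator can be checked on a toy layer. Without it, every output carries a `grad_fn` and autograd keeps the tensors needed for a backward pass. Under `torch.no_grad()`, or the stricter `torch.inference_mode()` available in recent PyTorch releases, nothing is recorded. The `nn.Linear` here is only a stand-in for the exported model.

```python
import torch

# Stand-in for the exported model: its parameters require grad by default.
layer = torch.nn.Linear(80, 256)
x = torch.randn(1, 100, 80)


def forward_plain(inputs):
    return layer(inputs)


@torch.no_grad()
def forward_no_grad(inputs):
    return layer(inputs)


y_plain = forward_plain(x)
y_nograd = forward_no_grad(x)

# Autograd recorded the op and keeps what backward would need alive.
print(y_plain.requires_grad, y_plain.grad_fn is not None)  # True True
# Nothing was recorded inside the decorated function.
print(y_nograd.requires_grad, y_nograd.grad_fn)  # False None

# inference_mode also skips view and version-counter tracking.
with torch.inference_mode():
    y_inf = layer(x)
print(y_inf.requires_grad, y_inf.is_inference())  # False True
```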

tz301 commented 3 months ago

Hi @csukuangfj,

Yes, adding @torch.no_grad() seems to work: memory usage decreased from 2.5 GB to 2.1 GB.

I'm not sure whether using ~2 GB of memory is normal. Are there any other ways to decrease it?

csukuangfj commented 3 months ago

Could you post your updated code?

tz301 commented 3 months ago


csukuangfj commented 3 months ago

Does memory usage grow linearly from 0 to 2.1 GB and then stay at 2.1 GB?

_LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] [{text}]')

Could you share the output of the above log line? What is the maximum duration?
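To answer the growth question file by file, one option is to log the process's peak resident set size next to each duration. Below is a minimal sketch using the standard-library `resource` module. It is Unix only, and `ru_maxrss` is in kilobytes on Linux and bytes on macOS. The helper name `peak_rss_mib` is hypothetical. Because the value is a running peak, it never decreases. `psutil.Process().memory_info().rss` reports the current value instead.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Return this process's peak resident set size so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    if sys.platform == "darwin":
        return peak / 2**20
    return peak / 2**10


# Hypothetical change to the loop in _main():
#     text = asr.recognize([wav_file])[0]
#     _LOGGER.info(f'[{wav_file}]: [{duration:.3f}s] '
#                  f'[peak {peak_rss_mib():.0f} MiB] [{text}]')
print(f"peak RSS so far: {peak_rss_mib():.0f} MiB")
```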

tz301 commented 3 months ago

The maximum duration among my wav files is 20 s.

Memory usage first grows to around 1.5 GB over the first few wavs, then slowly grows to 2.1 GB and stays there.

The first few wavs (around 5) are extremely slow and can take 1 to 2 minutes to finish decoding. I'm not sure whether it is normal for this ASR model to need that much warm-up time.
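One hypothesis worth testing is that the slow first calls come from TorchScript's graph executor, which profiles and re-optimizes a scripted model during its first runs. This is not confirmed in this thread, and varying segment lengths may add to the effect. The sketch times each call of a scripted module with graph-executor optimization on and off, using `torch.jit.optimized_execution`. `TinyModel` is a hypothetical stand-in; swap in the loaded `model.pt` and real feature lengths to compare.

```python
import time

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for the exported Zipformer (only used for timing)."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


def time_calls(model, lengths, optimize=True):
    """Return the wall-clock seconds of each forward call, in call order."""
    timings = []
    # optimized_execution toggles the TorchScript graph executor's optimizations.
    with torch.jit.optimized_execution(optimize), torch.no_grad():
        for num_frames in lengths:
            x = torch.randn(1, num_frames, 80)
            start = time.perf_counter()
            model(x)
            timings.append(time.perf_counter() - start)
    return timings


# Varying input lengths (in frames), like VAD segments of different durations.
lengths = [2000, 1500, 800, 2000, 1200, 600, 2000]
for flag in (True, False):
    # Script a fresh module per setting so warm-up state is less likely to carry over.
    scripted = torch.jit.script(TinyModel().eval())
    timings = time_calls(scripted, lengths, optimize=flag)
    print(f"optimize={flag}: " + ", ".join(f"{s * 1000:.1f} ms" for s in timings))
```

Suppose the first calls become fast with optimization off, but later calls get slower. That would point to executor warm-up rather than to memory allocation. In that case, running a few dummy inputs through the model right after loading it is a common way to pay the warm-up cost up front.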