2noise / ChatTTS

A generative speech model for daily dialogue.
https://2noise.com

Add NPU Support #777

Closed · shen-shanshan closed 1 month ago

shen-shanshan commented 1 month ago

What does this PR do?

Overview

This PR enables ChatTTS users to leverage the Ascend NPU for better inference performance when no GPU device is available.

For more details, see: #776.

Environment

[!NOTE]

To properly install CANN, see here for more details. In addition, the version of torch-npu must match that of torch; see here for more details.
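
As a quick sanity check (a minimal sketch, not part of this PR), the snippet below verifies that torch and torch_npu can be imported together and prints their versions; the exact version pairing should follow the compatibility table linked above:

import torch

try:
    # Importing torch_npu registers the NPU backend on torch.
    # It must be built against the installed torch version.
    import torch_npu
except ImportError:
    torch_npu = None

print("torch:", torch.__version__)
if torch_npu is None:
    print("torch_npu is not installed")
else:
    print("torch_npu:", torch_npu.__version__)
    print("NPU available:", torch.npu.is_available())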

Examples

To start with, the torch_npu library must be correctly installed and imported, after which we can select torch.device("npu:0") whenever torch.npu.is_available() returns True. Part of the code is shown below:

ChatTTS/utils/gpu.py:

  import torch
+ import torch_npu

  def select_device(min_memory=2047, experimental=False):
      if torch.cuda.is_available():
          selected_gpu = 0
          max_free_memory = -1
          for i in range(torch.cuda.device_count()):
              props = torch.cuda.get_device_properties(i)
              free_memory = props.total_memory - torch.cuda.memory_reserved(i)
              if max_free_memory < free_memory:
                  selected_gpu = i
                  max_free_memory = free_memory
-         free_memory_mb = max_free_memory / (1024 * 1024)
-         if free_memory_mb < min_memory:
-             logger.get_logger().warning(
-                 f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
-             )
-             device = torch.device("cpu")
-         else:
-             device = torch.device(f"cuda:{selected_gpu}")
+         device = _select_device_if_available(min_memory, selected_gpu, max_free_memory, "GPU")
+     elif torch.npu.is_available():
+         """
+         Using Ascend NPU to accelerate the process of inferencing when GPU is not found.
+         """
+         selected_npu = 0
+         max_free_memory = -1
+         for i in range(torch.npu.device_count()):
+             props = torch.npu.get_device_properties(i)
+             free_memory = props.total_memory - torch.npu.memory_reserved(i)
+             if max_free_memory < free_memory:
+                 selected_npu = i
+                 max_free_memory = free_memory
+         device = _select_device_if_available(min_memory, selected_npu, max_free_memory, "NPU")
      ...
      return device

+ def _select_device_if_available(min_memory, selected_device, max_free_memory, device_type: str):
+     free_memory_mb = max_free_memory / (1024 * 1024)
+     if free_memory_mb < min_memory:
+         logger.get_logger().warning(
+             f"{device_type} {selected_device} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
+         )
+         return torch.device("cpu")
+     else:
+         device = "cuda" if device_type == "GPU" else "npu"
+         return torch.device(f"{device}:{selected_device}")
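
For illustration (a hypothetical usage sketch, not part of the diff), the returned device can then be used like any other torch device:

import torch
from ChatTTS.utils.gpu import select_device

device = select_device()              # e.g. npu:0 on an NPU-only machine
x = torch.randn(2, 3, device=device)  # tensors are allocated on that device
print(x.device)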

In addition, a few other places in the code may need to be adjusted, though the changes should be small.

Feel free to leave comments to guide me in further improvements 😊.

Tests

This PR has passed the tests shown below:

Basic Usage Test

Normal mode:

[screenshot: iShot_2024-10-10_09 53 31]

Stream mode:

[screenshot: iShot_2024-10-10_09 56 11]

Logs are shown below:

(chattts) sss@xxx:~/github/ChatTTS$ python examples/web/webui.py
[+0000 20241010 01:49:47] [WARN]  WebUI  | funcs | no ffmpeg installed, use wav file output
[+0000 20241010 01:49:47] [INFO]  WebUI  | webui | loading ChatTTS model...
[+0000 20241010 01:49:47] [INFO] ChatTTS | dl | checking assets...
[+0000 20241010 01:49:51] [INFO] ChatTTS | dl | all assets are already latest.
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
[+0000 20241010 01:49:57] [INFO] ChatTTS | core | use device npu:0
/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
[+0000 20241010 01:49:57] [INFO] ChatTTS | core | vocos loaded.
[+0000 20241010 01:49:58] [INFO] ChatTTS | core | dvae loaded.
[+0000 20241010 01:49:58] [INFO] ChatTTS | core | embed loaded.
[+0000 20241010 01:49:59] [INFO] ChatTTS | core | gpt loaded.
[+0000 20241010 01:49:59] [INFO] ChatTTS | core | speaker loaded.
[+0000 20241010 01:49:59] [INFO] ChatTTS | core | decoder loaded.
[+0000 20241010 01:49:59] [INFO] ChatTTS | core | tokenizer loaded.
[+0000 20241010 01:49:59] [WARN]  WebUI  | funcs | Package nemo_text_processing not found!
[+0000 20241010 01:49:59] [WARN]  WebUI  | funcs | Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing
[+0000 20241010 01:49:59] [WARN]  WebUI  | funcs | Package WeTextProcessing not found!
[+0000 20241010 01:49:59] [WARN]  WebUI  | funcs | Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing
[+0000 20241010 01:49:59] [INFO]  WebUI  | webui | Models loaded successfully.
Running on local URL:  http://0.0.0.0:8080

To create a public link, set `share=True` in `launch()`.
/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/numba/cpython/hashing.py:482: UserWarning: FNV hashing is not implemented in Numba. See PEP 456 https://www.python.org/dev/peps/pep-0456/ for rationale over not using FNV. Numba will continue to work, but hashes for built in types will be computed using siphash24. This will permit e.g. dictionaries to continue to behave as expected, however anything relying on the value of the hash opposed to hash as a derived property is likely to not work as expected.
  warnings.warn(msg)
text:   0%|                                                                                                                                                                   | 0/384(max) [00:00, ?it/s]/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/transformers/generation/logits_process.py:476: UserWarning: AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, If you are looking for a user facing API to enable running your inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code is under risk of producing silent wrong result in some edge cases. See Note [AutoDispatchBelowAutograd] for more details. (Triggered internally at build/CMakeFiles/torch_npu.dir/compiler_depend.ts:74.)
  sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
text:   0%|▍                                                                                                                                                              | 1/384(max) [00:00,  1.63it/s]We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
text:  19%|██████████████████████████████                                                                                                                                | 73/384(max) [00:03, 21.63it/s]
code:  25%|██████████████████████████████████████▍                                                                                                                     | 504/2048(max) [00:21, 23.36it/s]
text:  19%|██████████████████████████████                                                                                                                                | 73/384(max) [00:02, 25.20it/s]
code:  25%|██████████████████████████████████████▍                                                                                                                     | 504/2048(max) [00:28, 17.97it/s]
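
For reference, stream mode can also be exercised directly from the Python API; the following is a hypothetical sketch (assuming chat.infer exposes the stream flag used by the webui), not part of this PR's tests:

import ChatTTS

chat = ChatTTS.Chat()
chat.load(source="local", compile=False)

# With stream=True, infer() yields audio chunks as they are generated.
for wavs in chat.infer(["两岸猿声啼不住,轻舟已过万重山。"], stream=True):
    # Each yield is a batch of numpy waveform chunks; play or buffer them here.
    ...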

Advanced Usage Test

The test script is shown below:

# tests/#776.py
import ChatTTS
from tools.logger import get_logger
from tools.audio import pcm_arr_to_mp3_view

import torch
import torchaudio

import os
import sys

def init():
    logger = get_logger("Test")

    now_dir = os.getcwd()
    sys.path.append(now_dir)

    chat = ChatTTS.Chat()
    chat.load(source="local", compile=False)

    # Sample a speaker from Gaussian.
    rand_spk = chat.sample_random_speaker()
    logger.info(f'rand_spk: {rand_spk}')
    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb = rand_spk, 
        temperature = .3,
        top_P = 0.7,
        top_K = 20,
    )

    """
    RefineTextParams: 
    For sentence level manual control.
    use oral_(0-9), laugh_(0-2), break_(0-7) to generate special token in text to synthesize.
    """ 
    params_refine_text = ChatTTS.Chat.RefineTextParams(
        prompt='[oral_2][laugh_1][break_6]',
    )

    return logger, chat, params_infer_code, params_refine_text

def save_audio(wavs, file_name: str):
    try:
        # wavs[0] is a 1-D waveform: add a channel dimension for torchaudio.
        torchaudio.save(f"{file_name}.wav", torch.from_numpy(wavs[0]).unsqueeze(0), 24000)
    except Exception:
        # wavs[0] already has a channel dimension.
        torchaudio.save(f"{file_name}.wav", torch.from_numpy(wavs[0]), 24000)

def save_mp3_file(wav, tag: str):
    data = pcm_arr_to_mp3_view(wav)
    mp3_filename = f"output_audio_{tag}.mp3"
    with open(mp3_filename, "wb") as f:
        f.write(data)
    logger.info(f"Audio saved to {mp3_filename}")

if __name__ == "__main__":
    logger, chat, params_infer_code, params_refine_text = init()
    logger.info("Initializing ChatTTS ...")

    # Test sentence-level manual control.
    texts = ["朝辞白帝彩云间,千里江陵一日还。两岸猿声啼不住,轻舟已过万重山。"]
    logger.info("Text input: %s", str(texts))
    wavs1 = chat.infer(
        texts,
        params_refine_text=params_refine_text,
        params_infer_code=params_infer_code,
    )
    save_mp3_file(wavs1[0], "sentence_level_test")

    # Test word-level manual control.
    text = '朝辞白帝[uv_break]彩云间[uv_break],千里江陵[uv_break]一日还[uv_break]。两岸猿声[uv_break]啼不住[laugh],轻舟[uv_break]已过[uv_break]万重山[lbreak]。'
    wavs2 = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=params_refine_text,
        params_infer_code=params_infer_code
    )
    save_mp3_file(wavs2[0], "words_level_test")

After running this script, we can see that the audio has been generated correctly:

[screenshot: iShot_2024-10-10_11 25 04]

Logs are shown below:

(chattts) sss@xxx:~/github/ChatTTS$ python tests/#776.py
/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work
  warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning)
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
[+0000 20241010 03:14:21] [INFO] Test | #776 | rand_spk: 蘁淰敝欀悃茸繎副涕胔欯膘癠俫磟堀襯劆淼厘嘲篐晾笌嘸佡撫穚趋贛橌旟燹届澿湮澉粰眥擮杖幔羒趬褦汍聶橊窯褺蟳敹臭庱哷圣揷囕熅戧稭嗵沪献曶帐瘯毤朳葢詿仚膛燇垕豞庙湡涜疋粅洽樊睁譃芫润翤擋淡挬扂核联祂毨绅燺嗃座姞湹厞更直尡碔庭佡椺晲嫬溑蛾嘁拳螗曌歿硆翅蛹虪儇磣毻岙瀹稯秌侞嬎癫悬兙珸衑槂榋淎蓻屠緘窌嬈毹牭玬巎莖墿窾荏祬胡巒倗璑灇丌榅筹媸淚犍嘱幈蓤蛮亚蓰傧蜔偯覧婯埲恭昁桩僳朕硈嚰婓挷卐虉嫦蒈瓖垟栰嚴搏探渡垕槅揧奬簄蠷疵螨懊來袰灄稜禫蘋畷汔謮揷诮朤炨冧汁揩殣趷搁糮苿忱翁欹灖袑好嘳叁蜏偰嫲亪斕翭琹寰蒒旁派柀比偽蔚诓搬焯懩襾秪胯栅案袋恙竰櫪榥瘎賳涋疜襙哴碄艢撉篂濪攟唤洋嘼灭离崆僞塁烵暐岥濙煴暼槓莎湟喣敐稃爙臓何傫紎樖臑嚎臻恃會楎匹濘菣佣盛侟恭嫺胆寴盙肽素襲訖咞櫵拜羵泱绀綨抱机諫覟蛯螔噺蠵槠浇淨厉虸瞲乇砷豕棢睄宴氐綑伨绹攇覈戙瓄瑅褗誅桿牮嗫曗榟埊下臐濌燒咿粪価奈体肧学棴热烯廿岫繠岉昫代樁淃尵巜嘵塷授洗葩勆玔冩廄耄啜僤褞詥跣斕謽災籉奟蘦蕛竅敶笔筗炏漺坙畅従笺睪葝恡莀腅撍哪坷癲蕱卍稢摝箇歯臋竟获潺橧譿蓤謥墲桋苓裑惕嚁獯糨衟岕弼蔄盖虬圷諴育尩膊螽繴溝舌姑茂譄帚湣勼玻腏犸椃孅孷庡徖规冪督淵愖耐虲抛聿蜷耗偌仨譝裝吐蠒虢蛜嗜菙蟏笂俯竕诋擁娫莆圙厁杚嶲撿涫灚眲此卋珑欙榼沷札誒岄琂炴栆揵坒琔奟傮匢瀔敋緽慶嘢稛詍厕莄蕖柘伲丐詌臾檔兼簝展廣泍攓睈拊糊珕芤哏汭詍痘瓜哎伀眃蠢樫諗謦牻葮潩蒺拏童壔粂賿溉笘戕挩咞譎炇窡掽嶜挬墤妔盾舙跔岇粞謖苟褵蛈跸侬梋虓扭夲絟慞眰刎宠珩星舯坧蝺硯譯佻聀祔禦萗氭绥抏攚幊珖嶀湘嚚婸场磹莣欕晼狉櫙浄櫸窎斦剑焱晵絹罣殨藭滻孮偿殮豢扒疺囐冴蚄蛀栛斪糑碓竼璵蘍羴橕噃墅瑆緟俌螾衚蝞蝊宨挙尷趶办哢歕翍攬纫嫢横赗浈潡衯畡塙舿象宾嫑蟆厑凙爺諼歷葁璪澢琨菆犢戶婲枕嚹凸掀憃袜犮挃捘快肆刧睦蟸戉翔歷彛棄稃譴篤廰甄訥槒捿纅粆僘歊媫嶠楮熖舖洂垻促捍浾瞛墢樘珰寋埾苅燗刺僥弒勨袟憛莫跗蟈緥拴勌婃扞灓腍蚧純箑绢撻咥蠄晜澃理曤旓伒瓦瘳賦煳杤俍蘽相脚袗孑摧滍媺凞晲衪瀗碙湟薨氤脡禶褴膲剝譌勣冉榥嶿橄檨冾攢澻标埘料毀烿檧攩葃螎婆呭烕嶲殐孜岺喧祶嘏蘀㴅
[+0000 20241010 03:14:21] [INFO] Test | #776 | Initializing ChatTTS ...
[+0000 20241010 03:14:21] [INFO] Test | #776 | Text input: ['朝辞白帝彩云间,千里江陵一日还。两岸猿声啼不住,轻舟已过万重山。']
/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/numba/cpython/hashing.py:482: UserWarning: FNV hashing is not implemented in Numba. See PEP 456 https://www.python.org/dev/peps/pep-0456/ for rationale over not using FNV. Numba will continue to work, but hashes for built in types will be computed using siphash24. This will permit e.g. dictionaries to continue to behave as expected, however anything relying on the value of the hash opposed to hash as a derived property is likely to not work as expected.
  warnings.warn(msg)
text:   0%|                                                                                                                                                                     | 0/384(max) [00:00, ?it/s]/home/sss/bin/miniconda/miniconda3/envs/chattts_2/lib/python3.10/site-packages/transformers/generation/logits_process.py:476: UserWarning: AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, If you are looking for a user facing API to enable running your inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code is under risk of producing silent wrong result in some edge cases. See Note [AutoDispatchBelowAutograd] for more details. (Triggered internally at build/CMakeFiles/torch_npu.dir/compiler_depend.ts:74.)
  sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
text:   0%|▍                                                                                                                                                                | 1/384(max) [00:00,  1.60it/s]We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
text:  10%|███████████████▍                                                                                                                                                | 37/384(max) [00:02, 17.59it/s]
code:  15%|████████████████████████▍                                                                                                                                     | 316/2048(max) [00:13, 23.98it/s]
[+0000 20241010 03:14:51] [INFO] Test | #776 | Audio saved to output_audio_sentence_level_test.mp3
code:   4%|██████▌                                                                                                                                                        | 84/2048(max) [00:03, 25.03i
...
code:  16%|█████████████████████████▌                                                                                                                                    | 332/2048(max) [00:13, 24.42it/s]
[+0000 20241010 03:15:06] [INFO] Test | #776 | Audio saved to output_audio_words_level_test.mp3