Standard-Intelligence / hertz-dev

first base model for full-duplex conversational audio
https://si.inc
Apache License 2.0

How to run Inference? #11

Open SuperMaximus1984 opened 12 hours ago

SuperMaximus1984 commented 12 hours ago

Config: Windows 10 with an RTX 4090. All requirements are installed, including the flash-attn build.
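For reference, here is a quick standalone sanity check of what this environment provides (just a snippet, not part of hertz-dev):

import torch
print(torch.__version__, torch.version.cuda)                      # PyTorch build and the CUDA toolkit it was built against
print(torch.cuda.is_available(), torch.cuda.get_device_name(0))   # should report True and the RTX 4090
import flash_attn                                                  # raises ImportError if the flash-attn build is broken
print(flash_attn.__version__)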

Server:

(venv) D:\PythonProjects\hertz-dev>python inference_server.py
Using device: cuda
<All keys matched successfully>
<All keys matched successfully>
Loaded tokenizer state dict: _IncompatibleKeys(missing_keys=[], unexpected_keys=['recon_metric.metrics.0.window', 'encoder.res_stack.0.pad_buffer', 'encoder.res_stack.1.res_block.0.conv1.pad_buffer', 'encoder.res_stack.1.res_block.1.conv1.pad_buffer', 'encoder.res_stack.1.res_block.2.conv1.pad_buffer', 'encoder.res_stack.1.res_block.3.pad_buffer', 'encoder.res_stack.2.res_block.0.conv1.pad_buffer', 'encoder.res_stack.2.res_block.1.conv1.pad_buffer', 'encoder.res_stack.2.res_block.2.conv1.pad_buffer', 'encoder.res_stack.2.res_block.3.pad_buffer', 'encoder.res_stack.3.res_block.0.conv1.pad_buffer', 'encoder.res_stack.3.res_block.1.conv1.pad_buffer', 'encoder.res_stack.3.res_block.2.conv1.pad_buffer', 'encoder.res_stack.3.res_block.3.pad_buffer', 'encoder.res_stack.4.res_block.0.conv1.pad_buffer', 'encoder.res_stack.4.res_block.1.conv1.pad_buffer', 'encoder.res_stack.4.res_block.2.conv1.pad_buffer', 'encoder.res_stack.5.res_block.0.conv1.pad_buffer', 'encoder.res_stack.5.res_block.1.conv1.pad_buffer', 'encoder.res_stack.5.res_block.2.conv1.pad_buffer', 'encoder.res_stack.6.res_block.0.conv1.pad_buffer', 'encoder.res_stack.6.res_block.1.conv1.pad_buffer', 'encoder.res_stack.6.res_block.2.conv1.pad_buffer', 'decoder.res_stack.0.res_block.2.conv1.pad_buffer', 'decoder.res_stack.0.res_block.3.conv1.pad_buffer', 'decoder.res_stack.0.res_block.4.conv1.pad_buffer', 'decoder.res_stack.1.res_block.2.conv1.pad_buffer', 'decoder.res_stack.1.res_block.3.conv1.pad_buffer', 'decoder.res_stack.1.res_block.4.conv1.pad_buffer', 'decoder.res_stack.2.res_block.2.conv1.pad_buffer', 'decoder.res_stack.2.res_block.3.conv1.pad_buffer', 'decoder.res_stack.2.res_block.4.conv1.pad_buffer', 'decoder.res_stack.3.res_block.1.conv1.pad_buffer', 'decoder.res_stack.3.res_block.2.conv1.pad_buffer', 'decoder.res_stack.3.res_block.3.conv1.pad_buffer', 'decoder.res_stack.4.res_block.1.conv1.pad_buffer', 'decoder.res_stack.4.res_block.2.conv1.pad_buffer', 'decoder.res_stack.4.res_block.3.conv1.pad_buffer', 'decoder.res_stack.5.res_block.1.conv1.pad_buffer', 'decoder.res_stack.5.res_block.2.conv1.pad_buffer', 'decoder.res_stack.5.res_block.3.conv1.pad_buffer', 'decoder.res_stack.6.pad_buffer'])
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: Memory efficient kernel not used because: (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:773.)
  x = F.scaled_dot_product_attention(
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen/native/transformers/sdp_utils_cpp.h:558.)
  x = F.scaled_dot_product_attention(
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: Flash attention kernel not used because: (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:775.)
  x = F.scaled_dot_product_attention(
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:599.)
  x = F.scaled_dot_product_attention(
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: CuDNN attention kernel not used because: (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:777.)
  x = F.scaled_dot_product_attention(
D:\PythonProjects\hertz-dev\transformer.py:195: UserWarning: CuDNN attention has been runtime disabled. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:528.)
  x = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "D:\PythonProjects\hertz-dev\inference_server.py", line 166, in <module>
    audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\inference_server.py", line 58, in __init__
    self.initialize_state(prompt_path)
  File "D:\PythonProjects\hertz-dev\inference_server.py", line 78, in initialize_state
    self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\model.py", line 323, in next_audio_from_audio
    next_latents = self.next_latent(latents_in, temps)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\model.py", line 333, in next_latent
    logits1, logits2 = self.forward(model_input)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\model.py", line 313, in forward
    x = layer(x, kv=self.cache[l])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\transformer.py", line 301, in forward
    h = self.attn(x, kv)
        ^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\transformer.py", line 253, in forward
    return x + self.attn(self.attn_norm(x), kv)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\venv\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\transformer.py", line 233, in forward
    return self._attend(q, k, v, kv_cache=kv)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\transformer.py", line 212, in _attend
    x = self._sdpa(q, k, v)
        ^^^^^^^^^^^^^^^^^^^
  File "D:\PythonProjects\hertz-dev\transformer.py", line 195, in _sdpa
    x = F.scaled_dot_product_attention(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: No available kernel. Aborting execution.
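Reading the warnings, it looks like every fused backend is out: this Windows PyTorch build was not compiled with flash attention, and the memory-efficient and cuDNN paths are disabled at runtime, so scaled_dot_product_attention has nothing left to dispatch to. A minimal probe like the one below (my own sketch, assuming PyTorch 2.3+ for torch.nn.attention) shows which SDPA backends a given install can actually run:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Small fp16 tensors shaped (batch, heads, seq_len, head_dim) on the GPU.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

for backend in (SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH):
    try:
        # Restrict dispatch to a single backend and see whether it runs.
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v)
        print(f"{backend.name}: OK")
    except RuntimeError as err:
        print(f"{backend.name}: unavailable ({err})")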

Any advice on how to run inference? Thank you!
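In the meantime, one thing I plan to try (purely a guess on my part, not a confirmed fix) is allowing the math backend around the failing scaled_dot_product_attention call in transformer.py, since that pure-PyTorch path should run even on a build without fused kernels. Roughly like this (the function name and signature here are illustrative, not the actual _sdpa):

import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def sdpa_with_math_fallback(q, k, v, **kwargs):
    # Let the dispatcher try the fused kernels first, but keep the slow,
    # always-available math implementation as a last resort.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                      SDPBackend.EFFICIENT_ATTENTION,
                      SDPBackend.MATH]):
        return F.scaled_dot_product_attention(q, k, v, **kwargs)

If that works it will be slower than a proper flash-attn path, but inference should at least run.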