artyom-beilis / pytorch_dlprim

DLPrimitives/OpenCL out-of-tree backend for PyTorch
http://blog.dlprimitives.org/
MIT License

The operator 'aten::arange.start_out' is not currently supported on the ocl backend. #43

Open leviathanch opened 8 months ago

leviathanch commented 8 months ago

When I try to run VITS I get this:

```
Accessing device #0:gfx900:xnack- on AMD Accelerated Parallel Processing
Text splitted to sentences.
['Hello world!']
Hello world! [!] Character 'H' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'e' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'l' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'o' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'w' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'r' not found in the vocabulary. Discarding it.
Hello world! [!] Character 'd' not found in the vocabulary. Discarding it.
/home/leviathan/.local/lib/python3.10/site-packages/torch/nn/functional.py:2233: UserWarning: The operator 'aten::index_select' is not currently supported on the ocl backend. Please open an issue at for requesting support https://github.com/artyom-beilis/pytorch_dlprim/issues (Triggered internally at /home/leviathan/Desktop/YouTube/AI/pytorch_dlprim/src/tensor_ops.cpp:311.)
  return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/utils/helpers.py:55: UserWarning: The operator 'aten::arange.start_out' is not currently supported on the ocl backend. Please open an issue at for requesting support https://github.com/artyom-beilis/pytorch_dlprim/issues (Triggered internally at /home/leviathan/Desktop/YouTube/AI/pytorch_dlprim/src/tensor_ops.cpp:311.)
  seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
Traceback (most recent call last):
  File "/home/leviathan/TTSTest/test1.py", line 15, in <module>
    wav = tts.tts(text="Hello world!", speaker_wav="annitta.wav")
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/api.py", line 341, in tts
    wav = self.synthesizer.tts(
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/utils/synthesizer.py", line 386, in tts
    outputs = synthesis(
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/utils/synthesis.py", line 221, in synthesis
    outputs = run_model_torch(
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/utils/synthesis.py", line 53, in run_model_torch
    outputs = _func(
  File "/home/leviathan/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/models/vits.py", line 1124, in inference
    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
  File "/home/leviathan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/leviathan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/layers/vits/networks.py", line 94, in forward
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
  File "/home/leviathan/.local/lib/python3.10/site-packages/TTS/tts/utils/helpers.py", line 55, in sequence_mask
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
RuntimeError: Buffer is not valid for unallocated defvice
```

artyom-beilis commented 8 months ago

Interesting.

RuntimeError: Buffer is not valid for unallocated defvice

This is the actual problem. Something fails in the CPU fallback; I need to check why.
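For context, a "CPU fallback" here means the backend runs an operator it does not implement (such as `aten::arange.start_out` from the warning) on the CPU and then moves the result onto the device. The sketch below illustrates that general pattern in plain PyTorch for an `out=`-style operator. It is a hypothetical illustration (the helper name `fallback_out_op` is made up), not pytorch_dlprim's actual C++ fallback in `src/tensor_ops.cpp`. One detail such a fallback has to handle is that the `out` tensor may start with zero elements and must be resized before the result is copied in.

```python
import torch


def _to_cpu(value):
    # Tensors are copied to host memory; Python scalars pass through unchanged
    return value.cpu() if isinstance(value, torch.Tensor) else value


def fallback_out_op(cpu_op, out, *args, **kwargs):
    """Hypothetical CPU fallback for an out= style operator.

    Runs cpu_op on CPU copies of the inputs, then writes the result into
    `out`, which may live on another device (e.g. "ocl:0").
    """
    cpu_args = [_to_cpu(a) for a in args]
    cpu_kwargs = {k: _to_cpu(v) for k, v in kwargs.items()}
    cpu_result = cpu_op(*cpu_args, **cpu_kwargs)
    # out= tensors may arrive with zero elements, so size them first
    out.resize_(cpu_result.shape)
    # copy_ performs the host-to-device transfer when out is not on the CPU
    out.copy_(cpu_result)
    return out


# Usage on the CPU only, with `out` standing in for a device tensor
out = torch.empty(0, dtype=torch.int64)
fallback_out_op(torch.arange, out, 0, 5)
print(out.tolist())  # [0, 1, 2, 3, 4]
```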

In any case, I'm working on layer_norm now to enable VITS in general; they currently don't work.

Also, I must say that a lot is still missing, and I haven't tested anything related to NLP (mostly due to lack of experience in that area).
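Until the fallback is fixed, one user-side workaround that may be worth trying (untested on the ocl backend) is to avoid calling `torch.arange` on the device at all: build the index range on the CPU and move it over with `.to()`. The function below is a hypothetical stand-in for the `sequence_mask` helper seen in the traceback (`TTS/tts/utils/helpers.py`, line 55). Its name is made up, and it assumes that `.to()`, `unsqueeze`, and the `<` comparison themselves work on the ocl device.

```python
import torch


def sequence_mask_cpu_arange(sequence_length, max_len=None):
    """Hypothetical variant of TTS's sequence_mask that avoids arange on the device.

    Returns a [batch, max_len] boolean mask that is True where the position
    index is smaller than that row's sequence length.
    """
    if max_len is None:
        max_len = int(sequence_length.max())
    # Build the position indices on the CPU, where arange is always available...
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype)
    # ...then move the small index tensor to the lengths' device
    seq_range = seq_range.to(sequence_length.device)
    # Broadcast [1, T] against [B, 1] to get a [B, T] mask
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


print(sequence_mask_cpu_arange(torch.tensor([1, 3]), 3).tolist())
# [[True, False, False], [True, True, True]]
```

To try it, the function would have to replace the `sequence_mask` name that `TTS/tts/layers/vits/networks.py` calls, for example by assigning it as an attribute on that module before running inference. This assumes the module imports `sequence_mask` by name, which the bare call in the traceback suggests; other call sites in the VITS model may need the same treatment.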

leviathanch commented 8 months ago

That's the code I've been running, in case it helps:

```python
import torch

# Load the pytorch_dlprim libraries, then name the PrivateUse1 backend "ocl"
torch.ops.load_library("/usr/local/lib/libpt_ocl.so")
torch.ops.load_library("/usr/local/lib/libdlprim_core.so")
torch.utils.rename_privateuse1_backend('ocl')

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits
from TTS.api import TTS

print(TTS().list_models())
# Bengali ("bn") VITS model, moved to the first OpenCL device
tts = TTS("tts_models/bn/custom/vits-female").to("ocl:0")

# This call raises the RuntimeError shown in the traceback above
wav = tts.tts(text="Hello world!", speaker_wav="annitta.wav")
tts.tts_to_file(text="Hello world!", speaker_wav="annitta.wav", file_path="output.wav")
```
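When reporting missing operators, it can help to list every operator the backend falls back on during a run instead of reading them out of the console. A minimal sketch, assuming the fallback warnings reach Python's `warnings` module (the `UserWarning:` prefix in the log suggests they do) and keep the wording shown in the log; `run_and_collect` and `unsupported_ops` are hypothetical helper names. Because the run in this issue ends in an exception, the helper catches the `RuntimeError` so the collected warnings are not lost.

```python
import re
import warnings

# Matches the fallback warning text shown in the log
_UNSUPPORTED = re.compile(
    r"The operator '([^']+)' is not currently supported on the ocl backend"
)


def unsupported_ops(messages):
    """Return operator names found in fallback warnings, in first-seen order."""
    found = []
    for msg in messages:
        match = _UNSUPPORTED.search(str(msg))
        if match and match.group(1) not in found:
            found.append(match.group(1))
    return found


def run_and_collect(fn, *args, **kwargs):
    """Call fn while recording warnings; return (result, error, op names)."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" keeps Python from suppressing repeated warnings
        warnings.simplefilter("always")
        try:
            result, error = fn(*args, **kwargs), None
        except RuntimeError as exc:
            result, error = None, exc
    return result, error, unsupported_ops(w.message for w in caught)
```

With the script above, something like `wav, err, ops = run_and_collect(tts.tts, text="Hello world!", speaker_wav="annitta.wav")` would return the exception instead of raising it, together with the names of the operators that triggered a fallback warning.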