suno-ai / bark

🔊 Text-Prompted Generative Audio Model
MIT License
34.89k stars 4.09k forks

Hugging Face optimized implementation fails #468

Open kamyarkarimi opened 10 months ago

kamyarkarimi commented 10 months ago

I am following https://huggingface.co/docs/transformers/model_doc/bark#combining-optimizaton-techniques. I am on a single-GPU ml.g5.16xlarge Linux instance on AWS SageMaker (specs: https://aws.amazon.com/ec2/instance-types/g5/). I have already installed Bark from its git repo, and it works with no issues on this machine. The problem appears when I try to use the optimized setup described in that guide, which goes through transformers (instead of from bark import SAMPLE_RATE, generate_audio, preload_models). I copied and pasted the code exactly, and I get the error below when running audio_array = model.generate(**inputs).

my code:

from transformers import AutoProcessor, BarkModel
from IPython.display import Audio
import torch
from optimum.bettertransformer import BetterTransformer
device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16).to(device)

# convert to bettertransformer
model = BetterTransformer.transform(model, keep_original_model=False)

# enable CPU offload
model.enable_cpu_offload()

processor = AutoProcessor.from_pretrained("suno/bark")
# model = BarkModel.from_pretrained("suno/bark")

voice_preset = "v2/en_speaker_6"

inputs = processor("Hi, my name is Frank", voice_preset=voice_preset)
audio_array = model.generate(**inputs)
audio_array = audio_array.cpu().numpy().squeeze()

the error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

Traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 1
----> 1 audio_array = model.generate(**inputs)
      2 audio_array = audio_array.cpu().numpy().squeeze()

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/models/bark/modeling_bark.py:1591, in BarkModel.generate(self, input_ids, history_prompt, **kwargs)
   1583 semantic_output = self.semantic.generate(
   1584     input_ids,
   1585     history_prompt=history_prompt,
   1586     semantic_generation_config=semantic_generation_config,
   1587     **kwargs_semantic,
   1588 )
   1590 # 2. Generate from the coarse model
-> 1591 coarse_output = self.coarse_acoustics.generate(
   1592     semantic_output,
   1593     history_prompt=history_prompt,
   1594     semantic_generation_config=semantic_generation_config,
   1595     coarse_generation_config=coarse_generation_config,
   1596     codebook_size=self.generation_config.codebook_size,
   1597     **kwargs_coarse,
   1598 )
   1600 # 3. "generate" from the fine model
   1601 output = self.fine_acoustics.generate(
   1602     coarse_output,
   1603     history_prompt=history_prompt,
   (...)
   1608     **kwargs_fine,
   1609 )

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/models/bark/modeling_bark.py:969, in BarkCoarseModel.generate(self, semantic_output, semantic_generation_config, coarse_generation_config, codebook_size, history_prompt, **kwargs)
    959 x_semantic_history, x_coarse = self.preprocess_histories(
    960     history_prompt=history_prompt,
    961     max_coarse_history=max_coarse_history,
   (...)
    965     codebook_size=codebook_size,
    966 )
    967 base_semantic_idx = x_semantic_history.shape[1]
--> 969 semantic_output = torch.hstack([x_semantic_history, semantic_output])
    971 n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))
    973 total_generated_len = 0

Cazforshort commented 8 months ago

Try putting the inputs on the device too.
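A minimal sketch of that suggestion, assuming `inputs` behaves like a (possibly nested) mapping of tensors; the `generate(self, input_ids, history_prompt, **kwargs)` signature in the traceback suggests it also carries a `history_prompt`. `move_to_device` is a hypothetical helper, not a transformers API:

```python
from collections.abc import Mapping
from typing import Any, Union

import torch


def move_to_device(obj: Any, device: Union[str, torch.device]) -> Any:
    """Recursively move every tensor in a nested mapping/list/tuple to `device`.

    Hypothetical helper (not part of transformers). Non-tensor leaves such as
    strings, ints or None are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, Mapping):
        # Rebuilt as a plain dict so nested entries (e.g. a history prompt) are moved too
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(item, device) for item in obj)
    return obj


# Usage with the names from the original code:
# inputs = move_to_device(processor("Hi, my name is Frank", voice_preset=voice_preset), device)
# audio_array = model.generate(**inputs)
```

Returning a plain dict keeps `model.generate(**inputs)` working, since keyword unpacking only needs a mapping.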

smithedits commented 7 months ago

Did you figure it out?

smartvnm commented 7 months ago

inputs = processor("Hi, my name is Frank", voice_preset=voice_preset).to(device)

That's all you need.
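To confirm the fix worked before calling `generate`, a quick device check can help. This is a minimal sketch; `misplaced_tensors` is a hypothetical helper. It compares against the `device` string from the original code rather than the model's parameters, on the assumption that with `enable_cpu_offload()` the weights may stay on the CPU between calls:

```python
from collections.abc import Mapping
from typing import Any, List, Union

import torch


def _on_device(tensor: torch.Tensor, expected: torch.device) -> bool:
    """True if `tensor` lives on `expected`; "cuda" without an index matches any GPU."""
    if tensor.device.type != expected.type:
        return False
    return expected.index is None or tensor.device.index == expected.index


def misplaced_tensors(obj: Any, device: Union[str, torch.device], path: str = "inputs") -> List[str]:
    """Return the path and device of every tensor in `obj` that is NOT on `device`.

    Hypothetical helper (not part of transformers); walks nested mappings, lists and tuples.
    """
    expected = device if isinstance(device, torch.device) else torch.device(device)
    if isinstance(obj, torch.Tensor):
        return [] if _on_device(obj, expected) else [f"{path} ({obj.device})"]
    found: List[str] = []
    if isinstance(obj, Mapping):
        for key, value in obj.items():
            found += misplaced_tensors(value, expected, f"{path}[{key!r}]")
    elif isinstance(obj, (list, tuple)):
        for index, item in enumerate(obj):
            found += misplaced_tensors(item, expected, f"{path}[{index}]")
    return found


# Usage right before generation, with the names from the original code:
# leftovers = misplaced_tensors(inputs, device)
# assert not leftovers, f"still not on {device}: {leftovers}"
# audio_array = model.generate(**inputs)
```

If .to(device) moved everything, the check returns an empty list; otherwise the reported paths show which tensors were left behind.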