Plachtaa / VALL-E-X

An open source implementation of Microsoft's VALL-E X zero-shot TTS model. Demo is available in https://plachtaa.github.io/vallex/
MIT License
7.58k stars 756 forks source link

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not tuple #172

Open ElinLiu0 opened 6 months ago

ElinLiu0 commented 6 months ago

Hi there,i'm trying to export VALL-E model into ONNX. To achieve this goal,i have done these things: First,i extract the promptlang2id() and textlang2id() functions like below:

import numpy as np
import torch
from typing import List,Union
language_ID = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}

def promptlang2id(prompt_lang:str):
    return torch.LongTensor(
        np.array([language_ID[prompt_lang]])).to(
        torch.device('cuda')
    )

def textlang2id(text_lang:Union[List,str]):
    if isinstance(text_lang, str):
        return torch.LongTensor(
            np.array([language_ID[text_lang]])).to(
            torch.device('cuda')
        )
    elif isinstance(text_lang,List):
        return torch.LongTensor(
            np.array([language_ID[tl] for tl in text_lang])).to(
            torch.device('cuda')
        )

Second,to let VALL-E model more focus on Tensor Processing,i modified the inference() function of VALL-E model like below to make it only accept numerical parameters:

def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language_id: torch.Tensor = None,
        text_language_id: torch.Tensor = None,
        best_of: int = 1,
        length_penalty: float = 1.0,
        return_worst: bool = False,
) -> torch.Tensor

Third,i have write an ONNX exportation script like:

import torch
import sys

sys.path.append('../')

import platform
import pathlib

if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
else:
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath

from models.vallex import VALLE
from macros import *
import random

model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
)
checkpoint = torch.load("../checkpoints/vallex-checkpoint.pt", map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(
    checkpoint["model"], strict=True
)
assert not missing_keys

model.forward = model.inference

x = torch.randint(0,1000,size=(1,158),dtype=torch.int64)
x_lens = torch.tensor([158],dtype=torch.int32)
y = torch.randint(0,2000,size=(1,728,8),dtype=torch.int32)
enroll_x_lens = 140
prompt_language_id = torch.tensor([2],dtype=torch.int64)
text_language_id = torch.randint(0,2,size=(1,18),dtype=torch.int64)
best_of = 1,
length_penalty = 1.0,
return_worst = False,

dynamic_axes = {
    'prompt_language_id': {1: 'num_channels'},
    'text_language_id': {1: 'num_channels'},
}

torch.onnx.export(
    model=model,
    args = (
        x,
        x_lens,
        y,
        enroll_x_lens,
        prompt_language_id,
        text_language_id,
        best_of,
        length_penalty,
        return_worst,
    ),
    f = "vallex.onnx",
    input_names = [
        'x',
        'x_lens',
        'y',
        'enroll_x_lens',
        'prompt_language_id',
        'text_language_id',
        'best_of',
        'length_penalty',
        'return_worst',
    ],
    output_names=[
        "codes"
    ],
    dynamic_axes=dynamic_axes
)

So here is the all reproduce steps,it raise the error:

File "/home/elin/anaconda3/envs/valle/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not tuple

Is there any wrong with my sample input Tensors or something?