bringtree / question_embedding

这个仓库的issues里记录了许多奇奇怪怪的东西(100+)。
1 stars 1 forks source link

multi-band MelGAN 用 MNN 推理 #197

Open bringtree opened 3 years ago

bringtree commented 3 years ago

step 1 torch2onnx:

#!/ssd4/exec/huangps/anaconda3/envs/melgan/bin/python

import torch
import torchvision
import numpy as np
from model.generator import Generator
from utils.hparams import HParam, load_hparam_str
from utils.pqmf import PQMF
import wave

# Export the trained vocoder generator to ONNX ("melgan.onnx").
#
# Load the training checkpoint and rebuild the hyper-parameter object from the
# string stored inside it under 'hp_str'.
checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
hp = load_hparam_str(checkpoint['hp_str'])

# Build the generator from the checkpoint's hyper-parameters, move it to the
# GPU and restore the generator weights stored under 'model_g'.
vocoder = Generator(hp.audio.n_mel_channels, hp.model.n_residual_layers,
                        ratios=hp.model.generator_ratio, mult = hp.model.mult,
                        out_band = hp.model.out_channels).cuda()
vocoder.load_state_dict(checkpoint['model_g'])
# NOTE(review): `inference` is not a keyword of torch.nn.Module.eval, so
# Generator presumably overrides eval() in model.generator -- confirm what
# inference=False toggles there before changing it.
vocoder.eval(inference=False)

# vocoder.inference(mel)

# Use one precomputed mel spectrogram as the tracing input: convert it to a
# float32 CUDA tensor and prepend a leading dimension of size 1.
mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")
mel = torch.from_numpy(mel).to(device='cuda', dtype=torch.float32)
mel = mel.unsqueeze(0)
dummy_input = mel
# Tensor names written into the ONNX graph; the later ONNX Runtime and MNN
# steps look the tensors up by these exact names.
input_names = [ "mel" ]
output_names = [ "output" ]

# Mark axes 0 and 2 of the "mel" input as variable-sized in the exported graph.
# Only the input is listed here; no dynamic axes are declared for "output".
dynamic_axes = {
    "mel" : {0: "batch_size", 2: "seq_len"}
}

# Trace the model with the dummy input and write melgan.onnx; verbose=True
# prints the exported graph.
torch.onnx.export(vocoder, dummy_input, "melgan.onnx", verbose=True, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)

# MAX_WAV_VALUE = 32768.0
# with torch.no_grad():
#     mel = mel.detach()
#     if len(mel.shape) == 2:
#         mel = mel.unsqueeze(0)
#     mel = mel.cuda()
#     audio = vocoder.inference(mel)
#     # For multi-band inference
#     if hp.model.out_channels > 1:
#         pqmf = PQMF()
#         audio = pqmf.synthesis(audio).view(-1)

#     audio = audio.squeeze() # collapse all dimension except time axis
#     audio = audio[:-(hp.audio.hop_length*10)]
#     audio = MAX_WAV_VALUE * audio
#     audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
#     audio = audio.short()
#     audio = audio.cpu().detach().numpy()

# print(audio.shape)
# print(audio[:10])
# with wave.open('1.wav', 'wb') as wavfile:
#     wavfile.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))
#     wavfile.writeframes(audio)
bringtree commented 3 years ago

step2 验证onnx

import onnxruntime
import numpy as np
import torch
import wave
import torchvision
from utils.hparams import HParam, load_hparam_str
from utils.pqmf import PQMF

# checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
# hp = load_hparam_str(checkpoint['hp_str'])

# Sanity-check the exported model by running it once through ONNX Runtime.
sess = onnxruntime.InferenceSession('./melgan.onnx', None)

# Must match the tensor names chosen when the model was exported.
# (`input_names` is not used below; the feed dict spells out 'mel' directly.)
input_names = [ "mel" ]
output_names = [ "output" ]

mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")

# Reshape to (1, 80, -1): a leading dimension of 1, 80 rows, and the remaining
# length inferred from the array size.
# NOTE(review): the export step cast the mel to float32 explicitly; if this
# .npy file is stored in another dtype, sess.run may reject it -- confirm.
mel = mel.reshape([1,80,-1])
# sess.run returns a list with one array per requested output name.
audio = sess.run(output_names, {'mel': mel})
print(audio)
# audio = torch.from_numpy(audio[0]).to(device='cpu', dtype=torch.float32)

# MAX_WAV_VALUE = 32768.0
# with torch.no_grad():

#     pqmf = PQMF()
#     audio = pqmf.synthesis(audio).view(-1)
#     audio = audio.squeeze()

#     audio = audio[:-(256*10)]
#     audio = MAX_WAV_VALUE * audio
#     audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
#     audio = audio.short()
#     audio = audio.cpu().detach().numpy()

# print(audio.shape)
# print(audio[:10])
# with wave.open('1.wav', 'wb') as wavfile:
#     wavfile.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))
#     wavfile.writeframes(audio)
bringtree commented 3 years ago

step3 导出MNN模型

 ./MNNConvert -f ONNX --modelFile melgan.onnx --MNNModel melgan.mnn --bizCode biz

numpy 转 bin(把 .npy 格式的 mel 特征导出为原始 float 二进制文件,供 step4 的 C++ 程序直接读取),参考:

https://blog.csdn.net/guyuealian/article/details/106422400
bringtree commented 3 years ago

step4 C++推理验证

//
//  vocoder.cpp
//  MNN
//
//

#include <math.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <MNN/Interpreter.hpp>

#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>

using namespace MNN;

int main(int argc, char *argv[]) {

    const auto melganModel = "/Users/peisonghuang/MNN/demo/model/melgan.mnn";
    const auto inputFileName = "/Users/peisonghuang/MNN/demo/model/LJ001-0001.bin";

    // create net and session
    auto mnnNet = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(melganModel));

    MNN::ScheduleConfig netConfig;
    netConfig.type = MNN_FORWARD_CPU;
    netConfig.numThread = 1;
    auto session = mnnNet->createSession(netConfig);
    auto input = mnnNet->getSessionInput(session, "mel");
    mnnNet->resizeTensor(input, {1, 80, 832});
    mnnNet->resizeSession(session);

//     read data from bin 80 832
    {
        MNN::Tensor givenTensor(input, Tensor::CAFFE);

        std::ifstream inputFile(inputFileName, std::ios::in | std::ios::binary);

        float fnum[80][832] = {0};
        inputFile.read((char *) &fnum, sizeof fnum);
        inputFile.close();

        for (int i = 0; i < 80; i++) {
            for (int j = 0; j < 832; j++) {
                givenTensor.host<float>()[i * 832 + j] = static_cast<float_t>(fnum[i][j]);
            }
        }
        input->copyFromHostTensor(&givenTensor);
    }

    // run...
    {
        AUTOTIME;
        mnnNet->runSession(session);
    }

    // get output
    {
        auto outputTensor = mnnNet->getSessionOutput(session, "output");
        auto nchwTensor = new Tensor(outputTensor, Tensor::CAFFE);
        outputTensor->copyToHostTensor(nchwTensor);
        for (int i = 0; i < 4; i++){
            for( int j = 0; j < 3; j++)
                std::cout << nchwTensor->host<float>()[i*53248+j] << " ";
            std::cout << std::endl;
        }

    }
    return 0;
}