Sharpness1i / dittts

dit_net #1

Open Sharpness1i opened 1 month ago

Sharpness1i commented 1 month ago

import logging
import math
import random
from dataclasses import dataclass, field, fields
import dataclasses
from math import pi
from typing import Sequence, Tuple, Union, Optional

import torch
from einops import rearrange, reduce, repeat
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np

from fish_speech.models.dit_llama.models.commons.align_ops import expand_states
from fish_speech.models.dit_llama.models.commons.llama import LLaMa
from fish_speech.models.dit_llama.models.commons.vc_modules import ConvGlobalStacks
import torchdiffeq
from diffusers import DDPMScheduler, DDIMScheduler

logger = logging.getLogger(__name__)

class NumberEmbedder(nn.Module):
def __init__(self, features: int, dim: int = 256):
    super().__init__()
    assert dim % 2 == 0, f"dim must be divisible by 2, found {dim}"
    self.features = features
    self.weights = nn.Parameter(torch.randn(dim // 2))
    self.to_out = nn.Linear(in_features=dim + 1, out_features=features)

def to_embedding(self, x: Tensor) -> Tensor:
    x = rearrange(x, "b -> b 1")
    freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
    fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
    fouriered = torch.cat((x, fouriered), dim=-1)
    return self.to_out(fouriered)

def forward(self, x: Union[Sequence[float], Tensor]) -> Tensor:
    if not torch.is_tensor(x):
        x = torch.tensor(x, device=self.weights.device)
    assert isinstance(x, Tensor)
    shape = x.shape
    x = rearrange(x, "... -> (...)")
    return self.to_embedding(x).view(*shape, self.features)  # type: ignore
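
As a quick orientation, a minimal shape check of the Fourier-feature embedder above (an illustrative sketch, not part of the original file; batch size and feature width are arbitrary):

    # hypothetical smoke test for NumberEmbedder: scalar times -> feature vectors
    emb = NumberEmbedder(features=256)
    t = torch.rand(4)          # e.g. diffusion times in [0, 1)
    out = emb(t)               # sin/cos Fourier features plus raw t, projected to 256 dims
    assert out.shape == (4, 256)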

class RMSNorm(nn.Module):
def __init__(self, dim, feat_dim=-1, eps=1e-5):
    super().__init__()
    self.rms = dim**-0.5
    self.feat_dim = feat_dim
    self.eps = eps
    self.scale = nn.Parameter(torch.ones(dim))

def forward(self, x, unscaled=False):
    norm = torch.norm(x, dim=self.feat_dim, keepdim=True) * self.rms
    if unscaled:
        return x / norm.clamp(min=self.eps)
    g = self.scale
    if self.feat_dim != -1:
        while g.ndim <= self.feat_dim:
            g = g[None]
        while g.ndim < x.ndim:
            g = g.unsqueeze(-1)
    return x / norm.clamp(min=self.eps) * g

"""For VDiffusion"""

class Distribution:
    """Interface used by different distributions"""

def __call__(self, num_samples: int, device: torch.device):
    raise NotImplementedError()

class UniformDistribution(Distribution):
def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
    super().__init__()
    self.vmin, self.vmax = vmin, vmax

def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
    vmax, vmin = self.vmax, self.vmin
    return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin

class LogitNormalDistribution(Distribution):
def __init__(self, mean: float = 0.0, std: float = 1.0):
    super().__init__()
    self.mean, self.std = mean, std

def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
    x = torch.from_numpy(np.random.lognormal(self.mean, self.std, num_samples)).to(device)
    return x / (1 + x)

class BernoulliDistribution(Distribution):
def __init__(self, v1, v2):
    super().__init__()
    self.map = torch.tensor([v1, v2]).unsqueeze(0)

def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
    index = (torch.rand(num_samples, device=device) > 0.5).long()
    return self.map.repeat(num_samples, 1).to(device)[
        torch.arange(num_samples), index
    ]

def extend_dim(x: Tensor, dim: int):
    """e.g. if dim = 4: shape [b] => [b, 1, 1, 1]"""
    return x.view(*x.shape + (1,) * (dim - x.ndim))

def Ts(t):
    """Builds a type template for a given type that accepts a list of instances"""
    return lambda *types: lambda: t(*[tp() for tp in types])

class Sequential(nn.Module):
    """Custom Sequential that includes all args"""

def __init__(self, *blocks):
    super().__init__()
    self.blocks = nn.ModuleList(blocks)

def forward(self, x: Tensor, *args) -> Tensor:
    for block in self.blocks:
        x = block(x, *args)
    return x

def Repeat(m, times: int):
    ms = (m,) * times
    return Sequential(*ms) if isinstance(m, nn.Module) else Ts(Sequential)(*ms)

class TimeEmbedding(nn.Module):
def __init__(self, modulation_features, num_layers: int = 2, bias=True):
    super().__init__()
    self.embedding = NumberEmbedder(features=modulation_features)
    self.mlp = Repeat(
        nn.Sequential(
            nn.Linear(modulation_features, modulation_features, bias=bias),
            nn.GELU(),
        ),
        times=num_layers,
    )

def forward(self, time):
    # Process time to time_features
    time_features = F.gelu(self.embedding(time))
    time_features = self.mlp(time_features)
    # Overlap features if more than one per batch
    if time_features.ndim == 3:
        time_features = reduce(time_features, "b n d -> b d", "sum")

    return time_features

def get_sigma(t: Tensor):
    angle = t * math.pi / 2
    sigma = torch.tan(angle)
    return sigma
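
Tying the helpers above together, a small illustrative sketch of how a diffusion time is sampled, mapped to a noise level, and embedded (sizes are arbitrary; not part of the original file):

    # hypothetical: draw times, map to a noise level, and embed them as conditioning features
    t_sampler = UniformDistribution(vmin=0.0, vmax=1.0)
    t = t_sampler(4)                                           # shape [4], values in [0, 1)
    sigma = get_sigma(t)                                       # tan(pi*t/2), unbounded as t -> 1
    time_features = TimeEmbedding(modulation_features=256)(t)  # shape [4, 256]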

@dataclass
class ModelArgs:
# frontend

phone_embed_dim: int = 512
tone_embed_dim: int = 128
#n_phone: int = 300 + 2
n_phone: int = 74 + 2

n_tone: int = 30 + 2

local_cond_dim: int = 512
time_embed_dim: int = 256

local_cond_project_type: str = "linear"  # conv
local_cond_conv_kernel: int = 9
local_cond_conv_padding: int = 4

# llama
encoder_dim: int = 1024
encoder_n_layers: int = 24
encoder_n_heads: int = 16
encoder_n_kv_heads: Optional[int] = None
mlp_extend: Optional[float] = None
max_seq_len: int = 8192
multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5
dropout: float = 0.0
ffn_dim_multiplier: Optional[float] = None
use_causal_attn: bool = False

causal: bool = False
use_qk_norm: str = ""  # head, channel
use_window_mask: bool = False
window_size: list = field(default_factory=lambda: [-1, -1])
window_type: str = "elemwise"  # elemwise, blockwise
llama_provider: str = "ctiga"

# speaker encoder
spk_e_dim: int = 1024
spk_embed_dim: int = 512

# postnet
postnet_type: str = "linear"  # conv
postnet_kernel: int = 3

target: str = "bn"
prompt_feature: str = "bn"
in_channels: int = 24
out_channels: int = 24
#in_channels: int = 8
#out_channels: int = 8

use_textprefix: bool = True

bias: bool = False
target_type: str = "vector_field" 
#target_type: str = "epsilon"

# for uniform t
min_t: float = 0.0
max_t: float = 1.0
lognormal_mean: float = 0.0
lognormal_std: float = 1.0

p_max_t: float = 0.0

use_seg_embed: bool = True
use_bn_eos_bos: bool = True

max_phone_len: int = 2000
max_bn_len: int = 4000

flashattn_version: str = "2.3"

use_expand_ph: bool = False
zero_xt_prompt: bool = False

use_causal_attn: bool = False

use_qk_norm: bool = False
use_cache: bool = False
max_cache_batch_size: int = 10
use_bpe: bool=False
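
The training task in a later comment constructs this config and overrides a few fields before building the model; a minimal sketch of that pattern (the values shown mirror that task code, nothing new):

    # hypothetical: mirror how the training task overrides these defaults
    config = ModelArgs()
    config.target_type = "vector_field"   # CFM training target
    config.use_expand_ph = False
    config.zero_xt_prompt = False
    # note: the Diffusion module below also reads hp.mel_dim, which the YAML config supplies separately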

class Diffusion(nn.Module):
def __init__(self, hp):
    super().__init__()
    self.hp = hp

    self.min_t = hp.min_t if hasattr(hp, "min_t") else 0
    self.max_t = hp.max_t if hasattr(hp, "max_t") else 1

    self.target_type = hp.target_type if hasattr(hp, "target_type") else "vector_field"

    self.act_fn = nn.GELU()
    self.bias = hp.bias

    # text.
    self.ph_proj = nn.Sequential(
        nn.Embedding(hp.n_phone, hp.phone_embed_dim, padding_idx=0),
        nn.Linear(hp.phone_embed_dim, hp.encoder_dim)

    )

    # time-embedding
    #self.time_embedding = TimeEmbedding(hp.time_embed_dim, bias=self.bias)
    from fish_speech.models.dit_llama.models.commons.time_embedding import DDPMTimestepEmbedder, CFMTimeEmbedding
    self.time_embedding = CFMTimeEmbedding(hp.time_embed_dim, bias=self.bias)
    if self.target_type == 'epsilon':
        self.time_embedding = DDPMTimestepEmbedder(hp.time_embed_dim)

    self.in_channels = hp.mel_dim
    self.out_channels = hp.mel_dim
    # global speaker-embedding
    self.prompt_encoder = ConvGlobalStacks(
            idim=self.in_channels, n_chans=hp.spk_e_dim, odim=hp.spk_embed_dim)

    local_cond_in_channels = self.out_channels + hp.spk_embed_dim

    self.local_cond_project = nn.Linear(
        local_cond_in_channels, hp.local_cond_dim, bias=self.bias)

    if not hasattr(hp, "window_size"):
        hp.window_size = [-1, -1]

    self.encoder = LLaMa(hp)

    self.x_prenet = nn.Linear(self.in_channels, hp.encoder_dim, bias=self.bias)
    self.prenet = nn.Linear(
        hp.time_embed_dim + hp.local_cond_dim, hp.encoder_dim, bias=self.bias
    )
    if hp.use_expand_ph:
        self.expand_ph_prenet = nn.Linear(
            hp.encoder_dim, hp.encoder_dim, bias=self.bias
        )

    self.postnet = nn.Linear(hp.encoder_dim, self.out_channels, bias=False)

    #self.sigma_distribution = UniformDistribution(vmin=self.min_t, vmax=self.max_t)

    self.use_seg_embed = hp.use_seg_embed
    if hp.use_seg_embed:
        self.seg_embed = nn.Embedding(3, hp.encoder_dim, padding_idx=0)
        nn.init.trunc_normal_(self.seg_embed.weight, std=0.02, a=-0.04, b=0.04)

    if hp.use_bn_eos_bos:
        self.bn_eos_bos = nn.Parameter(torch.randn(2, hp.encoder_dim))  

    if self.target_type == 'epsilon':
        self.noise_scheduler = DDPMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 
            num_train_timesteps=1000, rescale_betas_zero_snr=False, clip_sample=False)
    elif self.target_type == 'vector_field':
        ''' sigma==0 means ODE'''
        #from modules.tts.ps2.flow_matching.vp_cfm import VariancePreservingConditionalFlowMatcher
        from torchcfm import VariancePreservingConditionalFlowMatcher
        self.flow_matcher = VariancePreservingConditionalFlowMatcher(sigma=0.0)
    else:
        raise NotImplementedError

def forward(self, inputs, sigmas=None, x_noisy=None):
    """ text projection """
    #ph_embed = self.ph_proj(inputs["phone"]) + self.tone_proj(inputs["tone"])  # [B, T, 1024]
    ph_embed = self.ph_proj(inputs["phone"])  # [B, T, 1024]
    ph_lens, feat_lens = inputs["text_lens"], inputs["lat_lens"]  #[B]

    ctx_feature = inputs['lat'] * inputs['ctx_mask']    #[B,T,C]
    B, device = ctx_feature.size(0), ctx_feature.device

    """ speaker embedding (global speaker reference) """
    spk_emb = self.prompt_encoder(ctx_feature)          #[B,C=512]     
    spk_emb = spk_emb * (1 - inputs["spk_cfg_mask"])
    spk_emb = spk_emb.unsqueeze(1).expand(-1, ctx_feature.shape[1], -1) #[B,T,C=512] (repeat to the latent time dim)

    """ local conditioning (prompt_latent + spk_embed) """
    ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:,:,None])
    local_cond = torch.cat([ctx_feature, spk_emb], dim=-1)
    local_cond = self.local_cond_project(local_cond)

    """ diffusion target latent """
    x = inputs['lat']

    if self.target_type == "epsilon":
        noise = torch.randn_like(x)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (B,), device=x.device,
            dtype=torch.int64
        )
        # time embedding
        #t = timesteps.float() / self.noise_scheduler.config.num_train_timesteps
        #time_emb = self.time_embedding(t)
        time_emb = self.time_embedding(timesteps)
        time_emb = time_emb.unsqueeze(1).expand(-1, local_cond.shape[1], -1)
        # define noisy_input and target
        x_noisy = self.noise_scheduler.add_noise(x, noise, timesteps)
        target = noise
    elif self.target_type == "vector_field":
        # Here, x is x1 in CFM
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
        # time embedding
        time_emb = self.time_embedding(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, local_cond.shape[1], -1)
        # define noisy_input and target
        x_noisy = xt
        target = ut
    else:
        raise NotImplementedError

    # concat condition.
    if self.hp.zero_xt_prompt:
        x_noisy = x_noisy * (1 - inputs['ctx_mask'])
    x_noisy = self.x_prenet(x_noisy) + self.prenet(
        torch.cat([local_cond,time_emb], dim=-1)   #[B,T,C=1024]
    )
    if self.hp.use_expand_ph:
        rand_num = random.random()
        if rand_num <= 0.33:
            x_noisy += expand_states(ph_embed, inputs['mel2ph'])
        elif rand_num <= 0.67:
            x_noisy += expand_states(ph_embed, inputs['sparsified_dur'])
        else:
            pass

    # Add BPE tokens
    if self.hp.use_bpe:
        bpe_embed, bpe_lens = self.bpe_proj(inputs["bpe"]), inputs["bpe_lengths"]
        text_embed = add_prefix(bpe_embed, bpe_lens, ph_embed, ph_lens)
        text_lens = ph_lens + bpe_lens
    else:
        text_embed = ph_embed
        text_lens = ph_lens

    if self.hp.use_bn_eos_bos:
        inputs["text_mel_mask"] = F.pad(inputs["text_mel_mask"], (2, 0), "constant", 1)
        feat_lens = feat_lens + 2
        bn_bos = self.bn_eos_bos[0][None, None, :].expand(x_noisy.shape[0], -1, -1)
        x_noisy = torch.cat([bn_bos, x_noisy, torch.zeros_like(bn_bos)], dim=1)
        indics_x = torch.arange(x_noisy.shape[1], device=device)[None, :]
        mask_eos = (indics_x == feat_lens[:, None]-1)   # mark the EOS slot (last valid frame) so it is replaced with the learned EOS embedding below
        x_noisy[mask_eos] = self.bn_eos_bos[1][None, None, :]
        #breakpoint()

    # concat prefix-text.
    if self.hp.use_textprefix:
        T = inputs["text_mel_mask"].shape[1]
        C = x_noisy.shape[-1]
        x_noisy_wtext = torch.full(
            [B, T, C], 0, device=device, dtype=x_noisy.dtype
        )
        if self.use_seg_embed:  #True
            seg_input = torch.zeros([B, T], device=x_noisy.device).long()

        T_text = text_embed.shape[1]
        T_feat = x_noisy.shape[1]
        indics_x = torch.arange(T, device=device)[None, :]
        mask_x = (indics_x < text_lens[:, None]) & (indics_x < T_text)      # mask_x: text positions (the first text_lens slots) are 1
        mask_text = (
            torch.arange(text_embed.shape[1], device=device)[None, :]
            < text_lens[:, None]
        )
        x_noisy_wtext[mask_x] = text_embed[mask_text].to(dtype=x_noisy_wtext.dtype)
        if self.use_seg_embed:
            seg_input[mask_x] = 1

        mask_x = (
            (text_lens[:, None] <= indics_x)
            & (indics_x < (text_lens + feat_lens)[:, None])
            & (indics_x - text_lens[:, None] < T_feat)
        )       #Padding Mask
        mask_noisy = (
            torch.arange(T_feat, device=device)[None, :] < feat_lens[:, None]
        )
        x_noisy_wtext[mask_x] = x_noisy[mask_noisy].to(dtype=x_noisy_wtext.dtype)
        if self.use_seg_embed:
            seg_input[mask_x] = 2

        encoder_input = x_noisy_wtext
        seq_mask = inputs["text_mel_mask"]
        if self.use_seg_embed:
            seg_output = self.seg_embed(seg_input)
            encoder_input = encoder_input + seg_output
    else:
        encoder_input = x_noisy
        seq_mask = inputs[f"{self.hp.prompt_feature}_mask"]

    #LLAMA Encode
    encoder_out = self.encoder(encoder_input, seq_mask)

    if self.hp.use_bn_eos_bos:
        feat_lens = feat_lens - 2

    if self.hp.use_textprefix:
        pred_v_wotext = torch.zeros(B, x.shape[1], encoder_out.shape[-1], device=device)

        T0 = x.shape[1]  ##latent
        T1 = encoder_out.shape[1]
        indics0 = torch.arange(T0, device=device)[None, :]
        indics1 = torch.arange(T1, device=device)[None, :]

        mask0 = (indics0 < feat_lens[:, None]) & (
            (text_lens[:, None] + indics0) < T1
        )
        if self.hp.use_bn_eos_bos:
            mask1 = (text_lens[:, None] < indics1) & (
                indics1 < (text_lens[:, None] + feat_lens[:, None]+1)
            )
        else:
            mask1 = (text_lens[:, None] <= indics1) & (
                indics1 < (text_lens[:, None] + feat_lens[:, None])
            )

        pred_v_wotext[mask0] = encoder_out[mask1].to(dtype=pred_v_wotext.dtype)

        pred_v = pred_v_wotext

    pred = self.postnet(pred_v)  #(B,T,C)

    ret_dict = {
        "pred_v": pred.transpose(1, 2),
        "target_v": target.transpose(1, 2),
    }
    return ret_dict
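
For reference, the vector-field branch above delegates x_t and u_t to torchcfm's VariancePreservingConditionalFlowMatcher with sigma = 0. A standalone sketch of the trigonometric (variance-preserving) interpolant it is based on, mirroring the published VP-CFM formulas rather than the library's exact internals:

    # assumed VP-CFM relations with sigma = 0 (x0 = noise, x1 = data):
    #   x_t = cos(pi*t/2) * x0 + sin(pi*t/2) * x1
    #   u_t = (pi/2) * (cos(pi*t/2) * x1 - sin(pi*t/2) * x0)
    def vp_cfm_sample(x0: Tensor, x1: Tensor):
        t = torch.rand(x0.shape[0], device=x0.device)      # one time per batch element
        t_ = t.view(-1, *([1] * (x0.ndim - 1)))            # broadcast over time/channel dims
        xt = torch.cos(math.pi / 2 * t_) * x0 + torch.sin(math.pi / 2 * t_) * x1
        ut = math.pi / 2 * (torch.cos(math.pi / 2 * t_) * x1 - torch.sin(math.pi / 2 * t_) * x0)
        return t, xt, ut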

def _forward(self, x, local_cond, text_embed, bpe_embed, timesteps, ctx_mask, dur=None, cfg_w=1.0, use_seq_cfg=False, seq_cfg_w=[1.0,1.0]):
    """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
    time_emb = self.time_embedding(timesteps)
    time_emb = time_emb.unsqueeze(1).expand(local_cond.shape[0], local_cond.shape[1], -1)
    if self.hp.zero_xt_prompt:
        x = x * (1 - ctx_mask)
    x = self.x_prenet(x) + self.prenet(torch.cat([local_cond,time_emb], dim=-1))

    if dur is not None:
        x += expand_states(text_embed, dur)

    # Add BPE tokens
    if self.hp.use_bpe:
        text_embed = torch.cat([bpe_embed, text_embed], dim=1)

    if self.hp.use_bn_eos_bos:
        bn_bos = self.bn_eos_bos[0][None, None, :].expand(x.shape[0], -1, -1)
        bn_eos = self.bn_eos_bos[1][None, None, :].expand(x.shape[0], -1, -1)
        x = torch.cat([bn_bos, x, bn_eos], dim=1)

    if self.hp.use_textprefix:
        x = torch.cat([text_embed, x], dim=1)
        if self.use_seg_embed:
            seg_input = torch.zeros([1, x.shape[1]], device=x.device).long()
            seg_input[:, :text_embed.shape[1]] = 1
            seg_input[:, text_embed.shape[1]:] = 2
            seg_output = self.seg_embed(seg_input)
            x = x + seg_output

    pred_v = self.encoder(x, attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))

    if self.hp.use_textprefix:
        pred_v = pred_v[:, text_embed.shape[1] :, :]

    if self.hp.use_bn_eos_bos:
        pred_v = pred_v[:, 1:-1, :]

    pred = self.postnet(pred_v)

    """ Perform CFG """
    """
    if use_torchdiffeq:
        if use_rescale:
            weight = 4
            rescale = 1
            pos, neg = pred[0:1], pred[1:2]
            # Apply regular classifier-free guidance.
            cfg = neg + weight * (pos - neg)
            # Calculate standard deviations.
            std_pos = pos.std([1,2], keepdim=True)
            std_cfg = cfg.std([1,2], keepdim=True)
            # Apply guidance rescale with fused operations.
            factor = std_pos / std_cfg
            factor = rescale * factor + (1 - rescale)
            pred = cfg * factor
        else:
            pred = torchdiffeq_cfg * pred[0:1] + (1 - torchdiffeq_cfg) * pred[1:2]
    """
    if cfg_w != 1:
        if use_seq_cfg:
            cond_spk_txt, cond_txt, uncond = pred.chunk(3)
            pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) \
                + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
            # pred = cond_spk_txt + seq_cfg_w[1]*(cond_spk_txt - cond_txt) + seq_cfg_w[0]*(cond_txt - uncond)
        else:
            cond_spk_txt, uncond = pred.chunk(2)
            pred = uncond + cfg_w * (cond_spk_txt - uncond)
    return pred
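
As a toy numeric illustration of the two-term sequential CFG combination above (scalar stand-ins; the weights follow the seq_cfg_w=[3, 4.5] setting used in the task code below):

    # toy check of: uncond + w0*(cond_txt - uncond) + w1*(cond_spk_txt - cond_txt)
    uncond, cond_txt, cond_spk_txt = 0.5, 0.8, 1.0
    w0, w1 = 3.0, 4.5                     # seq_cfg_w[0], seq_cfg_w[1]
    guided = uncond + w0 * (cond_txt - uncond) + w1 * (cond_spk_txt - cond_txt)
    # guided = 0.5 + 3.0*0.3 + 4.5*0.2 = 2.3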

@torch.no_grad()
def inference(self, inputs, timesteps=20, cfg_w=1.0, use_seq_cfg=False, seq_cfg_w=[1.0, 1.0], **kwargs):
    #text_embed = self.ph_proj(inputs["phone"]) + self.tone_proj(inputs["tone"])  # [B, T, 1024]
    text_embed = self.ph_proj(inputs["phone"])  # [B, T, 1024]
    if self.hp.use_bpe:
        bpe_embed = self.bpe_proj(inputs['bpe'])
    else:
        bpe_embed = None

    # speaker embedding
    ctx_feature = inputs['lat_ctx']
    spk_emb = self.prompt_encoder(ctx_feature)
    if cfg_w != 1:
        spk_emb[1:, :] = 0 # enable global spk cfg
        ctx_feature[1:, :, :] = 0 # prefix spk cfg
        #breakpoint()
    spk_emb = spk_emb.unsqueeze(1).expand(-1, ctx_feature.shape[1], -1)

    # local conditioning.
    local_cond = torch.cat([ctx_feature, spk_emb], dim=-1)
    local_cond = self.local_cond_project(local_cond)

    if self.target_type == 'epsilon':
        # Build scheduler
        scheduler = DDIMScheduler.from_config(
            self.noise_scheduler.config,
            rescale_betas_zero_snr=False, clip_sample=False, set_alpha_to_one=False, thresholding=False)
        scheduler.set_timesteps(50)

        t = 50
        _, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
        x = torch.randn([1, frm_len, self.out_channels], device=device)
        x = x * scheduler.init_noise_sigma

        #sigmas = torch.linspace(self.max_t, self.min_t, t + 1, device=device)
        #sigmas = repeat(sigmas, "i -> i b", b=1)

        # step through the DDIM schedule; the loop variable is the current diffusion timestep
        for i in scheduler.timesteps:
            t_cur = i.long().reshape(1).to(x.device)
            if cfg_w != 1:
                v_pred, v_pred_uncond = self._forward(
                    x, local_cond, text_embed, bpe_embed,
                    timesteps=t_cur.repeat(2), ctx_mask=inputs['ctx_mask']
                ).chunk(2)
                v_pred = cfg_w * v_pred + (1 - cfg_w) * v_pred_uncond
            else:
                v_pred = self._forward(
                    x, local_cond, text_embed, bpe_embed,
                    timesteps=t_cur, ctx_mask=inputs['ctx_mask']
                )
            x = scheduler.step(v_pred, i, x).prev_sample

    elif self.target_type == 'vector_field':
        #t = timesteps
        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))

        if cfg_w != 1:
            traj = torchdiffeq.odeint(
                lambda t, x: self._forward(torch.cat([x] * bsz), local_cond, text_embed, bpe_embed, timesteps=t.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=None, cfg_w=cfg_w, use_seq_cfg=use_seq_cfg, seq_cfg_w=seq_cfg_w),
                torch.randn([1, frm_len, self.out_channels], device=device),
                torch.linspace(0, 1, timesteps, device=device),
                atol=1e-4,
                rtol=1e-4,
                method="euler",
            )
        else:
            traj = torchdiffeq.odeint(
                lambda t, x: self._forward(x, local_cond, text_embed, bpe_embed, timesteps=t.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=None),
                torch.randn([1, frm_len, self.out_channels], device=device),
                torch.linspace(0, 1, timesteps, device=device),
                atol=1e-4,
                rtol=1e-4,
                method="dopri5",
            )
        x = traj[-1]

    else:
        raise NotImplementedError

    return x
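
For orientation, the flow-matching sampler above keeps only the last state of the ODE trajectory returned by torchdiffeq; a stripped-down sketch of that pattern with a stand-in vector field (illustrative only, not the model):

    # hypothetical: odeint returns one state per requested time point; traj[-1] is the sample at t = 1
    f = lambda t, x: -x                                    # stand-in vector field
    x0 = torch.randn(1, 10, 24)
    ts = torch.linspace(0, 1, 25)
    traj = torchdiffeq.odeint(f, x0, ts, method="euler")   # shape [25, 1, 10, 24]
    sample = traj[-1]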
Sharpness1i commented 1 month ago

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import pytorch_lightning as L
from models.melae_kl_new import MelAEKL
from models.mel_vqgan import Mel_VQGAN
from models.speechdit import Diffusion, ModelArgs
from tasks.mel_vqgan import MelVQGAN_task
from utils.utils import ssim, weights_nonzero_speech
from Bigvgan.bigvgan.models_bigvgan import BigVGAN as BigVGANGenerator
from utils.ckpt import load_ckpt
from utils.plot import spec_to_figure
from utils.schedulers import WarmupSchedule
import numpy as np
from scipy.io import wavfile
from memory_profiler import profile
import gc

# Vocoder Config

config = {
    'model': {
        'resblock_kernel_sizes': [3, 7, 11],
        'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        'upsample_rates': [5, 2, 2, 2, 2, 2],
        'upsample_initial_channel': 1024,
        'upsample_kernel_sizes': [11, 4, 4, 4, 4, 4]
    }
}

class DIT_task(L.LightningModule):
def __init__(self, hparams):
    super().__init__()
    self.h = hparams
    self.save_hyperparameters(hparams)

    self.save_vae_latent = torch.empty(0, device=torch.device('cuda'))

    self.diff_model = self.build_model()

    self.ckpt_dir = 'Bigvgan/checkpoints/bigvgan_16k_librilight'
    self.vocoder = BigVGANGenerator(160, **config['model'])
    load_ckpt(self.vocoder, self.ckpt_dir, 'model_gen', silent=True)
    self.vocoder.remove_weight_norm()
    for param in self.vocoder.parameters():
        param.requires_grad = False

def build_model(self):
    self.build_tts_model()
    self.gen_params = list(self.diff_model.parameters())
    #print(self.diff_model)
    #for n, m in self.diff_model.named_children():
    #    num_params(m, model_name=n)
    #if hparams['pitch_percep_ckpt'] != '':
    #    self.build_pitch_percep_model()
    return self.diff_model

def build_tts_model(self):
    hparams = self.h
    #load_ckpt(self.vae_model, hparams['vae_path'], strict=True)

    if hparams.get('use_mel', False):
        self.vae_stride = 1
    if hparams.get('use_jzy_vae', False):
        #JZY latent
        checkpoint_path= '/home/xj_data/yangqian/TTS/lightning-framework/good-exp/jzy-vae/model_ckpt_steps_2210000.ckpt'
        self.vae_stride = 8
        self.vae_model = MelAEKL(hidden_size=256, vae_beta=1e-2, latent_dim=8, stride=8)
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))
        model_state_dict = checkpoint['state_dict']['model']
        self.vae_model.load_state_dict(model_state_dict)

    else:
        #VQ Latent
        checkpoint_path= '/home/xj_data/yangqian/TTS/lightning-framework/good-exp/0701-librilight-vae-gan-latent24-stride4/version_1/checkpoints/epoch=2-step=440000.ckpt'
        #self.vqvae_model = Mel_VQGAN(hidden_size=192, vae_beta=1e-2, latent_dim=8)
        self.vae_stride = hparams["vae_stride"]
        self.vae_model = MelAEKL(hidden_size=192, vae_beta=1e-2, latent_dim=24, stride=4)
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))
        model_state_dict = checkpoint['state_dict']
        #cleaned_state_dict = {k.replace('vqvae.', ''): v for k, v in model_state_dict.items() if k.startswith('vqvae.')}
        cleaned_state_dict = {k.replace('vae.', ''): v for k, v in model_state_dict.items() if k.startswith('vae.')}
        self.vae_model.load_state_dict(cleaned_state_dict)

    config = ModelArgs()
    config.target_type = 'epsilon' if hparams.get('use_ddpm', False) else 'velocity'
    config.target_type = 'vector_field' if hparams.get('use_vpcfm', False) else config.target_type
    config.use_expand_ph = hparams.get('use_expand_ph', False)
    config.zero_xt_prompt = hparams.get('zero_xt_prompt', False)
    self.diff_model = Diffusion(config, hparams)
    self.cfg_mask_token_phone = config.n_phone - 1
    self.cfg_mask_token_tone = config.n_tone - 1
    self.mask_r = hparams.get('mask_r', 0.5)

def run_model(self, sample, infer=False, *args, **kwargs):
    hparams = self.h
    txt_tokens = sample["txt_tokens"]  # [B, T_t]
    #tone_tokens = sample["tone"]  # [B, T_t]
    txt_lengths = sample["txt_lengths"]  # [B, T_t]
    #bpe_tokens = sample["bpe_tokens"]  # [B, T_t]
    #bpe_lengths = sample["bpe_lengths"]  # [B, T_t]
    mels = sample["mels"]  # [B, T_s, 160]
    mel_lengths = sample["mel_lengths"]  # [B, T_s]
    #mel2ph = sample["mel2ph"][:, :: self.vae_stride]  # [B, T_mel/8]
    #latent_lengths = (mel2ph > 0).sum(-1)
    #loss_mask = (mel2ph > 0).float()[:, :, None]
    mels_loss_mask = (torch.abs(mels - (-6.0)) > 1e-3).float()
    mels_one_matrix = (mels * mels_loss_mask)[:, :: self.vae_stride].sum(-1)
    loss_mask = (torch.abs(mels_one_matrix - 0) > 1e-3).float()[:, :, None]
    #torch.set_printoptions(threshold=np.inf)

    #""" Disable the English tone (set them to 3)"""
    #en_tone_idx = ~((sample["tone"] == 4) | ( (11 <= sample["tone"]) & (sample["tone"] <= 15)) | (sample["tone"] == 0))
    #sample["tone"][en_tone_idx] = 3

    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True):
            if hparams.get('use_mel', False):
                vae_latent = mels
            else:
                vae_latent = self.vae_model.get_latent(mels).transpose(1,2)
                #vae_latent = self.vqvae_model.get_latent(mels).transpose(1,2)

    if not infer:
        ''' Assign random CFG Mask in Training '''
        if hparams.get('use_seq_cfg', False):
            spk_cfg_mask = torch.rand_like(txt_tokens[:, 0].float())[:, None]
            spk_cfg_mask = (spk_cfg_mask < 0.20).long()

            txt_cfg_mask = torch.rand_like(txt_tokens[:, 0].float())[:, None]
            txt_cfg_mask = (txt_cfg_mask < 0.50).long() * spk_cfg_mask

            txt_tokens = txt_tokens * (1 - txt_cfg_mask) + self.cfg_mask_token_phone * txt_cfg_mask
            #tone_tokens = tone_tokens * (1 - txt_cfg_mask) + self.cfg_mask_token_tone * txt_cfg_mask
            #bpe_tokens = bpe_tokens * (1 - txt_cfg_mask) + 32006 * txt_cfg_mask

        elif hparams.get('use_cfg', False):
            cfg_mask = torch.rand_like(txt_tokens[:, 0].float())[:, None]
            cfg_mask = (cfg_mask < 0.15).long() 
            txt_tokens = txt_tokens * (1 - cfg_mask) + self.cfg_mask_token_phone * cfg_mask
            #tone_tokens = tone_tokens * (1 - cfg_mask) + self.cfg_mask_token_tone * cfg_mask
            #bpe_tokens = bpe_tokens * (1 - cfg_mask) + 32006 * cfg_mask

        ctx_mask,latent_lengths = self.obtain_ctx_mask(mels, mels_loss_mask) # [B, T, 1]
        inputs = {
            'phone': txt_tokens,
            #'tone': tone_tokens,
            'text_lens': txt_lengths,
            #'bpe': bpe_tokens,
            #'bpe_lengths': bpe_lengths,
            'lat': vae_latent,
            'lat_lens': latent_lengths,
            'ctx_mask': ctx_mask,
            'lat_ctx': vae_latent * ctx_mask,
            'text_mel_mask': self.sequence_mask(latent_lengths + txt_lengths) > 0,  # padding mask over txt_lengths + latent_lengths
            'spk_cfg_mask': spk_cfg_mask,

            #'cfg_mask': cfg_mask
            #'mel2ph': mel2ph
        }

        ''' Diffusion training '''
        ret_dict = self.diff_model(inputs)
        pred_v, target_v = ret_dict['pred_v'].transpose(1, 2), ret_dict['target_v'].transpose(1, 2)

        losses, output = {}, {}
        # the loss is also computed over the prompt region; only the mel-padding frames are masked out
        losses["diff_loss"] = self.mse_loss(
            #pred_v * loss_mask * (1 - ctx_mask),
            pred_v * loss_mask,
            #target_v * loss_mask * (1 - ctx_mask),
            target_v * loss_mask,
        )

        if torch.any(torch.isnan(losses['diff_loss'])) or torch.any(torch.isinf(losses['diff_loss'])):
            print("NaN/INF occurs in loss")
            print(sample['item_name'])
            print('NaN! NaN! NaN! NaN! NaN! NaN! NaN! NaN!')
            exit(1)

        return losses, output
    else:
        ctx_mask,latent_lengths = self.obtain_ctx_mask(mels, mels_loss_mask, infer=True) # B, T, 1
        # Make CFG inputs
        txt_tokens_ = torch.cat([txt_tokens, txt_tokens, torch.full(txt_tokens.size(), self.cfg_mask_token_phone, device=txt_tokens.device)], 0)
        # tone_tokens_ = torch.cat([torch.full(tone_tokens.size(), self.cfg_mask_token, device=tone_tokens.device), tone_tokens], 0)
        # bpe_tokens_ = torch.cat([torch.full(bpe_tokens.size(), 32006, device=bpe_tokens.device), bpe_tokens], 0)

        vae_latent_ = vae_latent.repeat(3, 1, 1)
        txt_lengths_ = txt_lengths.repeat(3)
        # bpe_lengths_ = bpe_lengths.repeat(2)
        latent_lengths_ = latent_lengths.repeat(3)
        ctx_mask_ = ctx_mask.repeat(3,1,1)
        # mel2ph_ = mel2ph.repeat(2, 1)
        inputs = {
            'phone': txt_tokens_,
            #'tone': tone_tokens,
            'text_lens': txt_lengths_,
            #'bpe': bpe_tokens,
            #'bpe_lengths': bpe_lengths,
            'ctx_mask': ctx_mask_,
            'lat_ctx': vae_latent_ * ctx_mask_,
            'lat': vae_latent_,
            'lat_lens': latent_lengths_,
            #'text_mel_mask': self.sequence_mask(latent_lengths_ + txt_lengths_) > 0,
            #'mel2ph': mel2ph
            'dur': None,
        }

        x = self.diff_model.inference(inputs, timesteps=25, cfg_w=5, use_seq_cfg=True, seq_cfg_w=[3, 4.5])
        x = vae_latent * ctx_mask + x * (1 - ctx_mask)
        if hparams.get('use_mel', False):
            outputs = x
        else:
            outputs = self.vae_model.decode(x.transpose(1,2))
            #outputs = self.vqvae_model.decode(x.transpose(1,2))
        return outputs

#@profile
def training_step(self, batch, batch_idx):
    if not batch:
        return None
    loss_weights = {}
    loss_output, _ = self.run_model(batch)
    total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad])
    #loss_output['batch_size'] = batch['txt_tokens'].size()[0]
    self.log('diff_loss', total_loss,prog_bar=True)
    return total_loss

#@profile
def validation_step(self, batch, batch_idx):
    if not batch:
        return None
    hparams = self.h
    loss_weights = {}
    model_out = self.run_model(batch, infer=True)
    #for loss_name, loss_value in loss_output.items():
    #    self.log(loss_name, loss_value,prog_bar=True)
    #total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad])
    #loss_output['batch_size'] = batch['txt_tokens'].size()[0]

    with torch.no_grad():
        if hparams.get('use_mel', False):
            y_hat = model_out.detach().cpu()
        else:
            y_hat = mel_out = model_out.get('mel_out').detach().cpu()
        mels = batch["mels"]
        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True):
                gt_vae_latent = self.vae_model.get_latent(mels).transpose(1,2)
                #gt_vae_latent = self.vqvae_model.get_latent(mels).transpose(1,2)

        if hparams.get('use_mel', False):
            y = mels
        else:
            #y = self.vqvae_model.decode(gt_vae_latent.transpose(1,2)).get('mel_out').detach().cpu()
            y = self.vae_model.decode(gt_vae_latent.transpose(1,2)).get('mel_out').detach().cpu()

        if self.trainer.global_rank == 0 and batch_idx < 10:
            with torch.no_grad():
                wav_gt = self.spec2wav(y.squeeze(0))
                wav_gt = wav_gt.astype(float)
                wav_gt = wav_gt * 32767
                wav_recon = self.spec2wav(y_hat.squeeze(0))
                wav_recon = wav_recon.astype(float)
                wav_recon = wav_recon * 32767
                #wav_gt = torch.from_numpy(wav_gt).unsqueeze(1)
                self.plot_mel(batch_idx, y_hat, y)
                import os
                save_wav_dir = os.path.join(self.logger.log_dir, 'Val_Audio')
                os.makedirs(save_wav_dir, exist_ok=True)
                #print('saving... val...wav...')
                wavfile.write(f'{save_wav_dir}/epoch_{self.current_epoch}_step_{self.global_step}_batch_{batch_idx}_gt.wav', 16000, wav_gt.astype(np.int16))
                wavfile.write(f'{save_wav_dir}/epoch_{self.current_epoch}_step_{self.global_step}_batch_{batch_idx}_recon.wav', 16000, wav_recon.astype(np.int16))
                #self.logger.experiment.add_audio(f'gt_audio_{batch_idx}', wav_gt, self.current_epoch, sample_rate=16000)
                #self.logger.experiment.add_audio('pre_audio', wav_recon.unsqueeze(0), self.current_epoch, sample_rate=16000)

    #return total_loss

def l1_loss(self, decoder_output, target):
    # decoder_output : B x T x n_mel
    # target : B x T x n_mel
    l1_loss = F.l1_loss(decoder_output, target, reduction='none')
    weights = weights_nonzero_speech(target)
    l1_loss = (l1_loss * weights).sum() / weights.sum()
    return l1_loss

def mse_loss(self, decoder_output, target):
    # decoder_output : B x T x n_mel
    # target : B x T x n_mel
    assert decoder_output.shape == target.shape
    mse_loss = F.mse_loss(decoder_output, target, reduction='none')
    weights = weights_nonzero_speech(target)
    mse_loss = (mse_loss * weights).sum() / weights.sum()
    return mse_loss
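
The masked reductions above rely on `weights_nonzero_speech` from utils.utils; it is assumed to give weight 1 to non-padded (non-all-zero) mel frames, roughly like the sketch below (an assumption, not the project's actual implementation):

    # assumed behaviour: weight 1 for frames with any non-zero mel bin, else 0
    def weights_nonzero_speech_sketch(target):
        # target: [B, T, n_mel] -> weights: [B, T, n_mel]
        n_mel = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, n_mel)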

def ssim_loss(self, decoder_output, target, bias=6.0):
    # decoder_output : B x T x n_mel
    # target : B x T x n_mel
    assert decoder_output.shape == target.shape
    weights = weights_nonzero_speech(target)
    decoder_output = decoder_output[:, None] + bias
    target = target[:, None] + bias
    ssim_loss = 1 - ssim(decoder_output, target, size_average=False)
    ssim_loss = (ssim_loss * weights).sum() / weights.sum()
    return ssim_loss

def spec2wav(self, mel, **kwargs):
    device = self.device
    with torch.no_grad():
        if isinstance(mel, np.ndarray):
            mel = torch.FloatTensor(mel)
        mel = mel.unsqueeze(0).to(device)
        mel = mel.transpose(2, 1)
        y = self.vocoder.infer(mel).view(-1)
    y = y.to(torch.float32)
    wav_out = y.cpu().numpy()
    return wav_out

def plot_mel(self, batch_idx, spec_out, spec_gt=None, name=None, title='', f0s=None, dur_info=None):
    #vmin = hparams['mel_vmin']
    #vmax = hparams['mel_vmax']
    vmin = -6.0  # fallback mel floor used for padding below; the original reads hparams['mel_vmin']
    if len(spec_out.shape) == 3:
        spec_out = spec_out[0]
    if isinstance(spec_out, torch.Tensor):
        spec_out = spec_out.to(torch.float32)
        spec_out = spec_out.cpu().numpy()
    if spec_gt is not None:
        if len(spec_gt.shape) == 3:
            spec_gt = spec_gt[0]
        if isinstance(spec_gt, torch.Tensor):
            spec_gt = spec_gt.to(torch.float32)
            spec_gt = spec_gt.cpu().numpy()
        max_len = max(len(spec_gt), len(spec_out))
        if max_len - len(spec_gt) > 0:
            spec_gt = np.pad(spec_gt, [[0, max_len - len(spec_gt)], [0, 0]], mode='constant',
                             constant_values=vmin)
        if max_len - len(spec_out) > 0:
            spec_out = np.pad(spec_out, [[0, max_len - len(spec_out)], [0, 0]], mode='constant',
                              constant_values=vmin)
        spec_out = np.concatenate([spec_out, spec_gt], -1)
    name = f'mel_val_{batch_idx}' if name is None else name
    self.logger.experiment.add_figure(name, spec_to_figure(
        spec_out,title=title), self.global_step)

def obtain_ctx_mask(self, mels, loss_mask, infer=False, dtype=torch.float16):
    #mel2ph = sample["mel2ph"][:, :: self.vae_stride]  # [B, T_mel/8]
    #latent_lengths = (mel2ph > 0).sum(-1)          # [B], T_mel/8 after removing padding; used to mask 50% of the true length
    #B, T = latent_lengths.shape[0], latent_lengths.max()   #T: txt_token_length
    mels = mels * loss_mask
    mels = mels[:, :: self.vae_stride].sum(-1)
    latent_lengths = (torch.abs(mels - 0) > 1e-3).sum(-1) 
    #B, T = latent_lengths.shape[0], latent_lengths.max()
    B, T = latent_lengths.shape[0], mels.shape[-1]
    device = mels.device

    if infer:
        # inference with 50% mask
        ratio = torch.ones((B, 1), device=device) * 0.5
    else:
        if self.h.get("random_ctx", True):
            ratio = torch.rand((B, 1), device=device) * self.mask_r
        else:
            ratio = torch.ones((B, 1), device=device) * self.mask_r

    end_idx = latent_lengths[:, None] * ratio
    end_idx = (end_idx - 1).clamp(min=1)
    ctx_mask = torch.arange(T, device=device)[None, :]
    ctx_mask = (ctx_mask < end_idx) # ctx_mask looks like: [1, 1, 1, 0, 0, 0, 0]
    ctx_mask = ctx_mask[:,:,None].to(dtype)
    #if not infer:
    #    breakpoint()
    return ctx_mask,latent_lengths

def sequence_mask(self, seq_lens, max_len=None, device='cpu', dtype=torch.float16):
    if max_len is None:
        max_len = seq_lens.max()
    mask = torch.arange(max_len).unsqueeze(0).to(seq_lens.device) # [1, t]
    mask = mask < (seq_lens.unsqueeze(1)) # [1, t] + [b, 1] = [b, t]
    mask = mask.to(dtype)
    return mask

def configure_optimizers(self):
    optimizer_gen = torch.optim.AdamW(
        self.gen_params, lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0, eps=1e-07
    )
    return optimizer_gen

def build_scheduler(self, optimizer):
    from utils.nn.schedulers import WarmupSchedule
    return WarmupSchedule(optimizer, 0.0001, 2000)

def on_before_optimization(self, opt_idx):
    hparams = self.h
    if hparams.gradient_clip_norm > 0:
        torch.nn.utils.clip_grad_norm_(self.parameters(), hparams.gradient_clip_norm)
    if hparams.gradient_clip_val > 0:
        torch.nn.utils.clip_grad_value_(self.parameters(), hparams.gradient_clip_val)
Sharpness1i commented 1 month ago

defaults:

# project: speechdit_600w_baseline_agpu

project: test2

resume_weights_only: false

# Lightning Trainer

trainer:
  accelerator: gpu
  devices: auto
  strategy: ddp_find_unused_parameters_true
  precision: 16
  max_steps: 40_000_000
  val_check_interval: 1000
  benchmark: false

sample_rate: 16000
hop_length: 160
num_mels: 160
n_fft: 800
win_length: 800

root_dir: '/mnt/yuyin1/cbu-tts/tmp_data'
item_wav_map: 'data/item_wav_map_600w.json'

# Dataset Configuration

train_dataset:
  _target_: fish_speech.datasets.speechdit_dataset.SpeechDitDataset
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  root_dir: ${root_dir}
  item_wav_map: ${item_wav_map}
  split: 'train'
  save_items_fp: /mnt/yuyin1/zuojialong/all_data/data_items.json
  save_lens_fp: '/mnt/yuyin1/zuojialong/all_data/lengths_list.json'
  category:
    Cn: ['aishell3', 'didispeech', "APY210615005", "BB96", "hw_data", "aidatatang", "magic_data", "st-cmds"]
    # En: ["common", "english_cn", "english_usa", "english_usa_cn", "LibriTTS", "LibriTTS2", "Ljspeech", "VCTK", "hifitts", "hifitts2"]
    En: ["common"]
    # style: ["cn1", "cn_en", "en1"]
    # emotion: ["BB96_qinggan", "emotion_data"]
    # bigdata1w: ["entertainment", "entertainment2", "health", "health2", "life", "life2"]
    # test: ["test_en"]

val_dataset:
  _target_: fish_speech.datasets.speechdit_dataset.SpeechDitDataset
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  root_dir: ${root_dir}
  item_wav_map: ${item_wav_map}
  split: 'valid'
  category:
    test: ["test_cn"]

data:
  _target_: fish_speech.datasets.speechdit_dataset.FlowTokenDecoderDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 8
  batch_size: 64
  val_batch_size: 1
  max_tokens: 12800

model_config:

  # frontend
  phone_embed_dim: 512
  tone_embed_dim: 128
  n_phone: 225
  n_tone: 32
  local_cond_dim: 512
  time_embed_dim: 256
  local_cond_project_type: "linear"  # conv
  local_cond_conv_kernel: 9
  local_cond_conv_padding: 4

  # llama
  encoder_dim: 1024
  encoder_n_layers: 24
  encoder_n_heads: 16
  encoder_n_kv_heads: null
  mlp_extend: null
  max_seq_len: 8192
  multiple_of: 256  # make SwiGLU hidden layer size multiple of large power of 2
  norm_eps: 1e-5
  dropout: 0.0
  ffn_dim_multiplier: null
  use_causal_attn: False
  causal: False
  use_window_mask: False
  window_size: [-1, -1]
  window_type: "elemwise"  # elemwise, blockwise
  llama_provider: "ctiga"

  # speaker encoder
  spk_e_dim: 1024
  spk_embed_dim: 512

  # postnet
  postnet_type: "linear"  # conv
  postnet_kernel: 3
  target: "bn"
  prompt_feature: "bn"

  in_channels: 24
  out_channels: 24
  mel_dim: 160
  # in_channels: int = 8
  # out_channels: int = 8

  use_textprefix: True
  bias: False
  target_type: "vector_field"

  # for uniform t
  min_t: 0.0
  max_t: 1.0
  lognormal_mean: 0.0
  lognormal_std: 1.0
  p_max_t: 0.0
  use_seg_embed: True
  use_bn_eos_bos: True
  max_phone_len: 2000
  max_bn_len: 4000
  flashattn_version: "2.3"
  use_expand_ph: False
  zero_xt_prompt: False
  use_qk_norm: False
  use_cache: False
  max_cache_batch_size: 10
  use_bpe: False

  # training
  mask_r: 0.5
  use_seq_cfg: True
  random_ctx: False

# Model Configuration

model:
  _target_: fish_speech.models.dit_llama.SpeechDitTask
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  hparams: ${model_config}

generator:
  _target_: fish_speech.models.dit_llama.dit_net.Diffusion
  hp: ${model_config}

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 1e-4
  betas: [0.9, 0.98]
  eps: 1e-6

lr_scheduler:
  _target_: transformers.get_cosine_schedule_with_warmup
  _partial_: true
  num_warmup_steps: 4000
  num_training_steps: 2000000

callbacks:
  grad_norm_monitor:
    sub_module:
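
For context, a config in this shape is normally resolved and instantiated along the following lines (a sketch assuming Hydra/OmegaConf `_target_` conventions; the filename is a placeholder):

    # hypothetical loader: resolve interpolations and build the generator from its _target_
    from omegaconf import OmegaConf
    from hydra.utils import instantiate

    cfg = OmegaConf.load("speechdit.yaml")       # placeholder filename
    generator = instantiate(cfg.generator)       # -> Diffusion(hp=cfg.model_config)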

Sharpness1i commented 1 month ago

import random
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import sys
import os
from tqdm import tqdm

sys.path.append('')

import librosa
import numpy as np
import torch
import torchaudio
import torch.distributed as dist
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Sampler
from fish_speech.datasets.utils import DistributedDynamicBatchSampler
from torch.utils.data.distributed import DistributedSampler
from fish_speech.datasets.phone_text import build_token_encoder, TokenTextEncoder
from fish_speech.utils.indexed_datasets import IndexedDataset
import h5py
import json
import glob

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)

def get_melspectrogram(wav_path, fft_size=800, hop_size=160, win_length=800, window="hann", num_mels=160, fmin=0, fmax=8000, eps=1e-6, sample_rate=16000, center=False, mel_basis=None):

# Load wav
if isinstance(wav_path, str):
    wav, _ = librosa.core.load(wav_path, sr=sample_rate)
else:
    wav = wav_path

# Pad wav to the multiple of the win_length
if len(wav) % win_length < win_length - 1:
    wav = np.pad(wav, (0, win_length - 1 - (len(wav) % win_length)), mode='constant', constant_values=0.0)

# get amplitude spectrogram
x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
                      win_length=win_length, window=window, center=center)
linear_spc = np.abs(x_stft)  # (n_bins, T)

# get mel basis
fmin = 0 if fmin == -1 else fmin
fmax = sample_rate / 2 if fmax == -1 else min(fmax, sample_rate // 2)
if mel_basis is None:
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)

# calculate mel spec
mel = mel_basis @ linear_spc
mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
return mel  # (C,T)
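
A minimal usage sketch for the helper above (the wav path is a placeholder; the defaults match the 16 kHz, 160-bin setup used elsewhere in this issue):

    # hypothetical call: 160-bin log10-mel at 16 kHz, hop 160 (10 ms frames)
    mel = get_melspectrogram("/path/to/example.wav")   # placeholder path
    print(mel.shape)                                   # (160, num_frames)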

class SpeechDitDataset(Dataset):
def __init__(
    self,
    sample_rate: int = 22050,
    hop_length: int = 256,
    root_dir: str = "",
    category=None,
    item_wav_map: str = "",
    split='train',
    save_items_fp="",
    save_lens_fp="",
):
    super().__init__()

    self.sample_rate = sample_rate
    self.hop_length = hop_length
    self.root_dir = root_dir

    with open(item_wav_map,'r') as fp:
        self.item_wavfn_map = json.load(fp)

    if split == 'train':
        with open(save_items_fp,'r') as fp:
            self.data_items = json.load(fp)
        with open(save_lens_fp,'r') as fp:
            self.lengths_list = json.load(fp)
    else:
        self.lengths_list = []
        self.filter_data = []
        for lg in category:
            for ds in category[lg]:
                metadata_dir = os.path.join(root_dir, f"{lg}/{ds}/metadata")
                b_datas = glob.glob(f'{metadata_dir}/*.data')
                for b_data in b_datas:
                    ds = IndexedDataset(b_data[:-5])
                    for i in tqdm(range(len(ds))):
                        item = ds[i]
                        item_name = item['item_name']
                        code_len = len(item['code'])
                        ph_token = item['ph']
                        if code_len>1000 or code_len<50:
                            continue
                        if self.item_wavfn_map.get(item_name,None) is None:
                            continue
                        tmp_item = {'item_name':item_name,'wav_fn':self.item_wavfn_map[item_name],'ph_token':ph_token}
                        self.filter_data.append(tmp_item)
                        self.lengths_list.append(code_len*2)
        self.data_items = self.filter_data
        self.data_items = self.data_items[:25]
        self.lengths_list = self.lengths_list[:25]

def __len__(self):
    return len(self.data_items)

def get_item(self, idx):
    item = self.data_items[idx]
    audio_file = item['wav_fn']
    ph_token = np.array(item['ph_token']).astype(int) # [T]

    ### bigvgan mel
    mel = get_melspectrogram(audio_file)

    ph_token = torch.LongTensor(ph_token)
    mel = torch.from_numpy(mel)

    return {
        "mel": mel,
        "txt_token":ph_token
    }

def get_item_exception(self,idx):
    try:
        return self.get_item(idx)
    except Exception as e:
        import traceback

        traceback.print_exc()
        logger.error(f"Error loading {self.data_items[idx]}: {e}")
        return None

def __getitem__(self, idx):
    if isinstance(idx, list):
        batch = [self.get_item_exception(i) for i in idx]
        return batch
    else:
        return self.get_item_exception(idx)

@dataclass
class FlowTokenDecoderCollator:

def __call__(self, batch):
    batch = batch[0]
    batch = [x for x in batch if x is not None]

    txt_lengths = torch.tensor([x["txt_token"].shape[-1] for x in batch])
    txt_maxlen = txt_lengths.max()

    mel_lengths = torch.tensor([x["mel"].shape[-1] for x in batch])
    mel_maxlen = mel_lengths.max()

    # Rounds up to nearest multiple of 2 (audio_lengths)
    txt_tokens = []
    mels = []
    for x in batch:
        txt_tokens.append(
            torch.nn.functional.pad(x["txt_token"], (0, txt_maxlen - x["txt_token"].shape[-1]),value=0)
        )

        mels.append(
            torch.nn.functional.pad(x["mel"], (0, mel_maxlen - x["mel"].shape[-1]))
        )

    return {
        "txt_tokens": torch.stack(txt_tokens),
        "txt_lengths": txt_lengths,
        "mels": torch.stack(mels).transpose(1,2),
        "mel_lengths": mel_lengths,  
    }

@dataclass
class FlowTokenDecoderCollator_Val:

def __call__(self, batch):
    batch = [x for x in batch if x is not None]

    txt_lengths = torch.tensor([x["txt_token"].shape[-1] for x in batch])
    txt_maxlen = txt_lengths.max()

    mel_lengths = torch.tensor([x["mel"].shape[-1] for x in batch])
    mel_maxlen = mel_lengths.max()

    # Rounds up to nearest multiple of 2 (audio_lengths)
    txt_tokens = []
    mels = []
    for x in batch:
        txt_tokens.append(
            torch.nn.functional.pad(x["txt_token"], (0, txt_maxlen - x["txt_token"].shape[-1]),value=0)
        )

        mels.append(
            torch.nn.functional.pad(x["mel"], (0, mel_maxlen - x["mel"].shape[-1]))
        )

    return {
        "txt_tokens": torch.stack(txt_tokens),
        "txt_lengths": txt_lengths,
        "mels": torch.stack(mels).transpose(1,2),
        "mel_lengths": mel_lengths,
    }

class FlowTokenDecoderDataModule(LightningDataModule):
def __init__(
    self,
    train_dataset: SpeechDitDataset,
    val_dataset: SpeechDitDataset,
    batch_size: int = 32,
    num_workers: int = 4,
    val_batch_size: Optional[int] = None,
    max_tokens: int = 6400,  # maximum number of tokens per batch
):
    super().__init__()

    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.batch_size = batch_size
    self.val_batch_size = val_batch_size or batch_size
    self.num_workers = num_workers

    self.args = {'max_num_tokens':max_tokens,'num_buckets':6,'audio_max_length':20,'encodec_sr':100,'seed':1234}

def train_dataloader(self):
    world_size = torch.distributed.get_world_size()
    train_sampler = DistributedDynamicBatchSampler(
        self.train_dataset, 
        self.args, 
        num_replicas=world_size, 
        rank=torch.distributed.get_rank(), 
        shuffle=True, 
        seed=self.args['seed'], 
        drop_last=True, 
        lengths_list=self.train_dataset.lengths_list, 
        verbose=True, 
        batch_ordering='ascending',
        epoch=0)
    return DataLoader(
        self.train_dataset,
        sampler=train_sampler,
        collate_fn=FlowTokenDecoderCollator(),
        num_workers=self.num_workers//world_size,
        persistent_workers=True,
    )

def val_dataloader(self):
    return DataLoader(
        self.val_dataset,
        batch_size=self.val_batch_size,
        collate_fn=FlowTokenDecoderCollator_Val(),
        num_workers=self.num_workers,
        persistent_workers=True,
    )