XCYu-0903 / ROSE

ROSE: A Recognition-Oriented Speech Enhancement Framework in Air Traffic Control Using Multi-Objective Learning
MIT License

release model? #1

Closed: jeremy110 closed this issue 10 months ago

jeremy110 commented 10 months ago

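Hello, may I ask whether you plan to release the model?

I am implementing this paper and training on the Voice Bank + DEMAND dataset, but my results differ considerably from yours... Also, one part of the paper seems a bit odd: equation (2) describes element-wise addition, but the figure shows a concat at the end. Which one should I follow?

If you are willing to reply, I would be very grateful.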

jeremy110 commented 10 months ago
import math

import torch as th
from torch import nn
from torch.nn import functional as F

class BLSTM(nn.Module):
    def __init__(self, dim, layers = 2, bi = True):
        super().__init__()
        klass = nn.LSTM
        self.lstm = klass(bidirectional = bi, num_layers = layers, hidden_size = dim, input_size = dim)
        self.linear = None
        if bi:
            self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x, hidden = None):
        x, hidden = self.lstm(x, hidden)
        if self.linear:
            x = self.linear(x)
        return x, hidden

class CSAtt(nn.Module): # Channel and Sequence Attention
    def __init__(self, chin, r = 2):
        super(CSAtt, self).__init__()

        self.seq_path = nn.Sequential(
                            nn.Conv1d(chin, 1, 1, bias = False),
                            nn.Sigmoid()
                        )

        self.ch_path =  nn.Sequential(
                            nn.AdaptiveAvgPool1d(1), # GAP

                            nn.Conv1d(chin, chin // r, 1, bias = False), # Squeeze
                            nn.ReLU(inplace = True),

                            nn.Conv1d(chin // r, chin, 1, bias = False), # Excitation
                            nn.Sigmoid()
                        )

    def forward(self, x):
        '''
            x: (B, C, L)
            - C: channels
            - L: length
        '''
        x_path = x * self.seq_path(x).expand_as(x)
        x_ch = x * self.ch_path(x).expand_as(x)

        return x_path + x_ch

class ABSF(nn.Module): # Attention-based Skip-fusion
    def __init__(self, chin):
        super(ABSF, self).__init__()

        self.E_conv = nn.Conv1d(chin, chin, 1, bias = False)
        self.D_conv = nn.Conv1d(chin, chin, 1, bias = False)

        self.sigmoid = nn.Sigmoid()
        self.conv1 = nn.Conv1d(chin, chin, 1, bias = False)

    def forward(self, E_i, D_i1):
        '''
            E_i, D_i1: (B, C, L)
        '''
        B_i = self.sigmoid(self.E_conv(E_i) + self.D_conv(D_i1))
        A_i = self.sigmoid(self.conv1(B_i))

        return th.cat((E_i * A_i, D_i1), dim = 1)  # concat (as in the figure); Eq. (2) in the paper uses element-wise addition instead

class ROSE(nn.Module):
    """
    ROSE speech enhancement model.
    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - max_H (int): maximum number of hidden channels.
    """

    def __init__(self,
                 chin = 1,
                 chout = 1,
                 hidden = 48,
                 depth = 5,
                 kernel_size = 8,
                 stride = 4,
                 max_H = 768
                 ):

        super(ROSE, self).__init__()

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.resample = 1

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.absf = nn.ModuleList()

        for i in range(depth):
            self.encoder.append(nn.Sequential(
                nn.Conv1d(chin, hidden, kernel_size, stride, bias = False), # CNN block
                nn.ReLU(inplace = True),
                nn.Conv1d(hidden, hidden * 2, 1, bias = False), 
                nn.GLU(dim = 1),

                CSAtt(hidden)
            ))

            chin = hidden
            d_hidden = hidden * 2
            if i == 0:
                # no relu at end
                self.decoder.append(nn.Sequential(
                    CSAtt(chin * 2),

                    nn.Conv1d(chin * 2, hidden * 2, 1, bias = False), 
                    nn.GLU(dim = 1),
                    nn.ConvTranspose1d(hidden, chout, kernel_size, stride, bias = False),
                ))
            else:
                self.decoder.insert(0, nn.Sequential(
                    CSAtt(d_hidden),

                    nn.Conv1d(d_hidden, d_hidden * 2, 1, bias = False), 
                    nn.GLU(dim = 1),
                    nn.ConvTranspose1d(d_hidden, chout, kernel_size, stride, bias = False),
                    nn.ReLU()
                ))

            chout = hidden
            self.absf.append(ABSF(hidden))
            hidden *= 2
            hidden = min(hidden, max_H)

        self.lstm = BLSTM(chin, bi = True)
        self.absf = self.absf[::-1]

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in any convolution, i.e. for all
        layers, (input_size - kernel_size) % stride == 0.
        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)

    def forward(self, mix):
        '''
            mix: (B, C, L) or (B, L); a 2-D input is treated as mono (C = 1).
        '''
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        mean = mix.mean(dim=(1, 2), keepdim=True)
        std = mix.std(dim=(1, 2), keepdim=True)
        mix = (mix - mean) / (1e-5 + std)

        length = mix.shape[-1]
        x = mix
        x = F.pad(x, (0, self.valid_length(length) - length))

        skips_mask = []
        for encode in self.encoder:
            x = encode(x)
            # print('encode', x.size())
            skips_mask.append(x)

        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)

        x_mask = x
        for idx, decode in enumerate(self.decoder):
            skip = skips_mask.pop(-1)
            # print(f'skip:{idx}', skip.size())
            # print('x_mask', x_mask.size())
            x_mask = self.absf[idx](skip, x_mask)
            x_mask = decode(x_mask)
            # print('decode', x_mask.size())

        out = x_mask[..., :length]

        return std * out + mean

if __name__ == "__main__":
    x = th.rand([4, 32000])
    model = ROSE()
    out = model(x)
    print(out.size())

This is my network. It uses concatenation; could you tell me whether it matches the paper?
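Before the encoder, `forward` pads the input to `valid_length(length)`. As a quick check of that arithmetic with the default hyperparameters (depth=5, kernel_size=8, stride=4), here is a standalone copy of the method plus a hypothetical `encoder_lengths` helper that lists the time steps after each encoder layer. This is only an illustration of the padding logic, not part of the original model.

```python
import math

def valid_length(length, depth=5, kernel_size=8, stride=4, resample=1):
    """Standalone copy of the ROSE.valid_length arithmetic."""
    length = math.ceil(length * resample)
    for _ in range(depth):  # encoder Conv1d output length, rounded up
        length = math.ceil((length - kernel_size) / stride) + 1
        length = max(length, 1)
    for _ in range(depth):  # decoder ConvTranspose1d output length
        length = (length - 1) * stride + kernel_size
    return int(math.ceil(length / resample))

def encoder_lengths(length, depth=5, kernel_size=8, stride=4):
    """Hypothetical helper: time steps after each encoder Conv1d (no padding)."""
    lengths = []
    for _ in range(depth):
        length = (length - kernel_size) // stride + 1
        lengths.append(length)
    return lengths

padded = valid_length(32000)     # 2 s at 16 kHz, as in the __main__ test above
print(padded)                    # 32084, so forward() pads 84 zeros
print(encoder_lengths(padded))   # [8020, 2004, 500, 124, 30]
```

Because every intermediate length satisfies (length - 8) % 4 == 0, the transposed convolutions in the decoder rebuild exactly 32084 samples, and `out[..., :length]` trims back to the original 32000.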

XCYu-0903 commented 10 months ago

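@jeremy110 Hi, thank you for your interest in our work. My answers are as follows:

  1. As described in the paper, the ABSF module uses element-wise addition. We have, however, also validated concatenation experimentally, and the results are very close. The discrepancy between the figure and the text is an oversight; thank you for pointing it out.

  2. I'm glad you are interested in our work; your reproduction of the model is fairly close. As for the poor results you mention, I think there are a few other possible causes:

  3. Since the paper is currently under review, we will consider open-sourcing the model if it is accepted.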
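To make the add-versus-concat choice concrete, here is a NumPy sketch of the ABSF gating from the code above, with the 1x1 convolutions written as channel-mixing matrix products and random weights. The `mode="add"` branch assumes the element-wise addition combines the gated skip `E_i * A_i` with the decoder input `D_i1`; the exact form of Eq. (2) is not reproduced in this thread, so treat that as an assumption.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def conv1x1(weight, x):
    """1x1 Conv1d without bias: weight (C_out, C_in), x (B, C_in, L)."""
    return np.einsum("oc,bcl->bol", weight, x)

def absf(E_i, D_i1, W_e, W_d, W_1, mode="add"):
    """Attention-based skip fusion, mirroring the ABSF module above."""
    B_i = sigmoid(conv1x1(W_e, E_i) + conv1x1(W_d, D_i1))
    A_i = sigmoid(conv1x1(W_1, B_i))
    gated = E_i * A_i                      # attention-gated encoder skip
    if mode == "add":                      # assumed form of the element-wise addition
        return gated + D_i1                # keeps C channels
    return np.concatenate([gated, D_i1], axis=1)  # concat: 2C channels

rng = np.random.default_rng(0)
C, L = 48, 100
E = rng.standard_normal((2, C, L))
D = rng.standard_normal((2, C, L))
W_e, W_d, W_1 = (0.1 * rng.standard_normal((C, C)) for _ in range(3))
print(absf(E, D, W_e, W_d, W_1, "add").shape)     # (2, 48, 100)
print(absf(E, D, W_e, W_d, W_1, "concat").shape)  # (2, 96, 100)
```

With addition the fused tensor keeps C channels, so in the ROSE code above the first CSAtt and 1x1 Conv1d of each decoder layer (built for `chin * 2` or `d_hidden` inputs) would need half as many input channels; with concatenation they take 2C, as written.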
jeremy110 commented 10 months ago

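Hello, thank you for your quick reply.

1. Indeed, in my own implementation the two variants give similar scores. Thanks for the explanation.

2.1 For the loss function, I tested MultiResolutionSTFTLoss ( https://github.com/JaeBinCHA7/DEMUCS-for-Speech-Enhancement/blob/main/utils/loss.py ), and also replaced its SpectralConvergengeLoss with the MelSpectrogramLoss below, so I believe only the MFCC loss is missing. With both setups, however, PESQ lands around 2.7 and STOI is poor at only 0.9, and the inferred audio is almost unlistenable. I plan to chain this model with Whisper, so I left out the MFCC loss.

2.2 You mentioned adjusting the number of CSAtt layers. Does this mean some encoder or decoder layers do not include the CSAtt module?

2.3 I have tried various segment lengths. I originally used 4 s as in the paper, which gave the scores mentioned in 2.1; after reducing it to 2 s, PESQ is around 2.9 and STOI reaches 0.92.

3. OK, understood. Thanks for the explanation.

Thanks again for your reply.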

Loss functions (first mel_loss.py, then the STFT-based losses, which import MelSpectrogramLoss from it):

# Copyright 2021 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Mel-spectrogram loss modules."""

import librosa
import torch
import torch.nn.functional as F

# Hard-coded: the original check, LooseVersion(torch.__version__) >= LooseVersion("1.7"),
# relied on distutils, which was removed in Python 3.12.
is_pytorch_17plus = True

class MelSpectrogram(torch.nn.Module):
    """Calculate Mel-spectrogram."""

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize MelSpectrogram module."""
        super().__init__()
        self.fft_size = fft_size
        if win_length is None:
            self.win_length = fft_size
        else:
            self.win_length = win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
        )
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
        self.stft_params = {
            "n_fft": self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center": self.center,
            "normalized": self.normalized,
            "onesided": self.onesided,
        }
        if is_pytorch_17plus:
            self.stft_params["return_complex"] = False

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")

    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).

        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        x_stft = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
        x_stft = x_stft.transpose(1, 2)
        x_power = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))

        x_mel = torch.matmul(x_amp, self.melmat)
        x_mel = torch.clamp(x_mel, min=self.eps)

        return self.log(x_mel).transpose(1, 2)

class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        fs=16000,
        fft_size=400,
        hop_size=160,
        win_length=400,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize Mel-spectrogram loss."""
        super().__init__()
        self.mel_spectrogram = MelSpectrogram(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base,
        )

    def forward(self, y_hat, y):
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated single tensor (B, 1, T).
            y (Tensor): Groundtruth single tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        mel_hat = self.mel_spectrogram(y_hat)
        mel = self.mel_spectrogram(y)
        # Alternative: mel_loss = F.l1_loss(mel_hat, mel)
        # Spectral-convergence-style ratio, normalized by the ground-truth mel
        # (as in SpectralConvergengeLoss) rather than by the prediction.
        mel_loss = torch.norm(mel_hat - mel, p="fro") / torch.norm(mel, p="fro")
        return mel_loss

if __name__ == "__main__":
    mel_loss_fn = MelSpectrogramLoss()
    x = torch.rand(4, 1, 32000)
    y = torch.rand(4, 1, 32000)
    print(mel_loss_fn(x, y))
"""STFT-based Loss modules."""

import torch
import torch.nn.functional as F
from mel_loss import MelSpectrogramLoss

def stft(x, fft_size, hop_size, win_length, window=None):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T) or (B, 1, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor of length win_length (e.g. torch.hann_window).
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    if x.dim() == 3:
        # (B, 1, T) -> (B, T); torch.stft only accepts 1-D or 2-D input
        x = x.reshape(-1, x.size(-1))
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    real = x_stft.real
    imag = x_stft.imag

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)

class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergengeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))

class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=512, shift_size=100, win_length=400, window="hann_window"):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.register_buffer("window", getattr(torch, window)(win_length))
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Log STFT magnitude loss value (the spectral convergence
            term of the original STFT loss is omitted here).
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)

        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return mag_loss

class CustomLoss(torch.nn.Module):
    """L1 waveform loss + 0.1 * log-STFT magnitude loss + 0.1 * mel loss."""

    def __init__(self):
        super(CustomLoss, self).__init__()

        self.loss_fn_1 = STFTLoss()            # single-resolution log-STFT magnitude loss
        self.loss_fn_2 = torch.nn.L1Loss()     # waveform MAE
        self.loss_fn_3 = MelSpectrogramLoss()  # spectral-convergence-style log-mel loss

    def forward(self, preds, targets):
        mag_loss = 0.1 * self.loss_fn_1(preds, targets)
        mae_loss = self.loss_fn_2(preds, targets)
        mel_loss = 0.1 * self.loss_fn_3(preds, targets)

        loss = mel_loss + mag_loss + mae_loss

        return loss

if __name__ == "__main__":
    mel_loss_fn = CustomLoss()
    x = torch.rand(4, 32000)
    y = torch.rand(4, 32000)
    print(mel_loss_fn(x, y))
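The CustomLoss above has no MFCC term. For reference, an MFCC pipeline is: windowed frames, power spectrum, mel filterbank, log, then a DCT-II over the mel axis. The NumPy sketch below computes that and a hypothetical L1 distance `mfcc_l1`; the HTK mel scale, 40 filters, 13 coefficients, frame settings and the L1 form are all assumptions, not the paper's settings. NumPy is not differentiable, so for training the same steps would have to be expressed in PyTorch; `torchaudio.transforms.MFCC` is one existing option.

```python
import numpy as np

def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)  # HTK-style mel scale

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(sr, n_fft, n_mels, fmin=0.0, fmax=None):
    """Triangular mel filters, shape (n_mels, n_fft // 2 + 1)."""
    fmax = sr / 2 if fmax is None else fmax
    bin_hz = np.linspace(0.0, sr / 2, n_fft // 2 + 1)  # frequency of each rfft bin
    edges = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2))
    fb = np.zeros((n_mels, bin_hz.size))
    for m in range(n_mels):
        lo, mid, hi = edges[m], edges[m + 1], edges[m + 2]
        rising = (bin_hz - lo) / (mid - lo)
        falling = (hi - bin_hz) / (hi - mid)
        fb[m] = np.maximum(0.0, np.minimum(rising, falling))
    return fb

def mfcc(x, sr=16000, n_fft=400, hop=160, n_mels=40, n_mfcc=13, eps=1e-10):
    """MFCCs of a 1-D signal: windowed frames -> power spectrum -> log-mel -> DCT-II."""
    n_frames = 1 + (len(x) - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    frames = x[idx] * np.hanning(n_fft)                    # (frames, n_fft)
    power = np.abs(np.fft.rfft(frames, axis=-1)) ** 2      # (frames, n_fft // 2 + 1)
    log_mel = np.log(power @ mel_filterbank(sr, n_fft, n_mels).T + eps)
    k = np.arange(n_mfcc)[:, None]
    n = np.arange(n_mels)[None, :]
    dct = np.cos(np.pi / n_mels * (n + 0.5) * k)           # unnormalized DCT-II basis
    return log_mel @ dct.T                                 # (frames, n_mfcc)

def mfcc_l1(pred, target, **kwargs):
    """Hypothetical MFCC term: mean absolute difference of the MFCC matrices."""
    return float(np.mean(np.abs(mfcc(pred, **kwargs) - mfcc(target, **kwargs))))

rng = np.random.default_rng(0)
clean = rng.standard_normal(32000)
noisy = clean + 0.3 * rng.standard_normal(32000)
print(mfcc(clean).shape)      # (198, 13)
print(mfcc_l1(clean, clean))  # 0.0
print(mfcc_l1(noisy, clean) > 0)
```

If such a term were added, it would sit alongside `mel_loss`, `mag_loss` and `mae_loss` in `CustomLoss.forward` with its own weight, which would need tuning.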
XCYu-0903 commented 10 months ago

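@jeremy110 Hi, here are further answers to your questions:

  1. Regarding the number of CSAtt layers, consider removing CSAtt from the encoder and decoder layers closest to the LSTM (typically keeping it in the first 3 encoder layers and the last 3 decoder layers).
  2. For VoiceBank-DEMAND, the segment length can be set between 1 and 3 s; candidates are 1 s, 1.5 s, 2 s, 2.5 s and 3 s. For the shift, try 20% or 50%.
  3. STFT parameters can follow the multi-resolution setup. From signal-processing principles, the STFT parameters are closely tied to the chosen segment length.
  4. In practice, a PESQ of 2.97 or above meets the expected performance of our model.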
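Item 3 ties the STFT settings to the segment length. One concrete way to see this is to count how many frames each resolution produces per segment. The sketch below assumes 16 kHz audio (matching `fs=16000` in the MelSpectrogramLoss above) and three (fft_size, hop_size, win_length) triples that I believe are the usual ParallelWaveGAN/Demucs MultiResolutionSTFTLoss defaults; treat both as assumptions and substitute the values your loss actually uses.

```python
# Assumed (fft_size, hop_size, win_length) triples for a multi-resolution STFT loss.
RESOLUTIONS = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
SR = 16000  # assumed sample rate

def stft_frames(num_samples, n_fft, hop, center=True):
    """Number of frames torch.stft returns for a 1-D signal of num_samples."""
    if center:
        num_samples += 2 * (n_fft // 2)  # torch.stft pads n_fft // 2 on both sides
    return 1 + (num_samples - n_fft) // hop

for seconds in (1.0, 1.5, 2.0, 2.5, 3.0):
    n = int(seconds * SR)
    frames = [stft_frames(n, fft, hop) for fft, hop, _ in RESOLUTIONS]
    print(f"{seconds:.1f} s ({n} samples): frames per resolution = {frames}")

for fft, hop, win in RESOLUTIONS:
    print(f"fft={fft}: {SR / fft:.2f} Hz per bin, "
          f"window {1000 * win / SR:.1f} ms, hop {1000 * hop / SR:.2f} ms")
```

For a 2 s segment this gives 267, 134 and 641 frames; at 1 s the 2048-point resolution is left with only 67 frames, so very short segments give the coarsest resolution little context.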
jeremy110 commented 10 months ago

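Hello, thank you for your quick reply.

  1. OK, understood. I will try it again.
  2. About the shift: most code I have seen repeats the waveform a few times to fill the length when a clip is shorter than the target segment length, and picks a random start point and crops when it is longer. So do you mean that when a clip exceeds the target length, you slice it with a 20% or 50% shift?
  3. Understood, thanks for the explanation.
  4. Understood, thanks for the explanation.

Thanks again for your reply and explanations.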
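The common recipe described in item 2 (tile clips that are too short, randomly crop clips that are too long) can be sketched as below. `crop_pair` is a hypothetical helper; the detail that matters for enhancement is that the noisy input and the clean target share the same crop offset, so they stay time-aligned.

```python
import numpy as np

def crop_pair(noisy, clean, seg_len, rng):
    """Fixed-length (noisy, clean) pair: tile clips shorter than seg_len,
    randomly crop longer ones with the same offset for both signals."""
    assert len(noisy) == len(clean), "noisy/clean pairs must have equal length"
    if len(noisy) < seg_len:
        reps = -(-seg_len // len(noisy))  # ceiling division
        return np.tile(noisy, reps)[:seg_len], np.tile(clean, reps)[:seg_len]
    start = int(rng.integers(0, len(noisy) - seg_len + 1))  # upper bound is exclusive
    return noisy[start:start + seg_len], clean[start:start + seg_len]

rng = np.random.default_rng(0)
sr, seg_len = 16000, 2 * 16000
clean = rng.standard_normal(5 * sr)
noisy = clean + 0.1
n_seg, c_seg = crop_pair(noisy, clean, seg_len, rng)
print(n_seg.shape, np.allclose(n_seg - c_seg, 0.1))  # (32000,) True
```

The author's reply pads short clips instead of tiling them; replacing the `np.tile` branch with `np.pad` gives that variant.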

XCYu-0903 commented 10 months ago

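@jeremy110 Hi, regarding segmentation and shift:

  1. If a clip is shorter than the segment length, it is usually padded to the segment length. If it is longer, some approaches pick a random start point and crop, as you describe, while others cut the data into segments of the specified length.
  2. In my answer, the shift (also called the hop) is a data augmentation method: a percentage of the segment length is used as the step between segments, which expands the number of training samples drawn from the original data. For example, with a 2 s segment and a 50% shift, the step is 2 × 50% = 1 s, so a 4 s clip yields the segments [0, 2], [1, 3] and [2, 4]. This work uses this shifted segmentation. Apologies if my wording caused any confusion.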
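The shifted segmentation in item 2 can be sketched as follows. `shifted_segments` and `slice_with_shift` are hypothetical helpers; zero-padding of short clips follows item 1, while dropping any tail shorter than one segment is an assumption the reply does not address.

```python
import numpy as np

def shifted_segments(num_samples, seg_len, shift_ratio):
    """(start, end) indices of segments of seg_len samples, stepping by shift_ratio * seg_len."""
    hop = max(1, int(round(seg_len * shift_ratio)))
    return [(s, s + seg_len) for s in range(0, num_samples - seg_len + 1, hop)]

def slice_with_shift(wav, sr, seg_sec, shift_ratio):
    """Zero-pad clips shorter than one segment, then cut overlapping segments.
    Samples after the last full segment are dropped (assumption of this sketch)."""
    seg_len = int(round(seg_sec * sr))
    if len(wav) < seg_len:
        wav = np.pad(wav, (0, seg_len - len(wav)))
    return np.stack([wav[s:e] for s, e in shifted_segments(len(wav), seg_len, shift_ratio)])

sr = 16000
bounds = shifted_segments(4 * sr, 2 * sr, 0.5)
print([(s / sr, e / sr) for s, e in bounds])                   # [(0.0, 2.0), (1.0, 3.0), (2.0, 4.0)]
print(slice_with_shift(np.zeros(4 * sr), sr, 2.0, 0.2).shape)  # (6, 32000)
print(slice_with_shift(np.zeros(sr), sr, 2.0, 0.5).shape)      # (1, 32000)
```

The same boundaries should be applied to each noisy/clean pair so inputs and targets stay aligned; a 20% shift yields roughly 2.5 times as many segments as a 50% shift, at the cost of more overlap between them.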
jeremy110 commented 10 months ago

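Hello, thank you for your reply.

I will keep trying. Many thanks for your replies and explanations, and best wishes for the paper's acceptance.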