Closed jeremy110 closed 10 months ago
import math

import torch as th
from torch import nn
from torch.nn import functional as F


class BLSTM(nn.Module):
    def __init__(self, dim, layers=2, bi=True):
        super().__init__()
        klass = nn.LSTM
        self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = None
        if bi:
            self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x, hidden=None):
        x, hidden = self.lstm(x, hidden)
        if self.linear:
            x = self.linear(x)
        return x, hidden
class CSAtt(nn.Module):  # Channel and Sequence Attention
    def __init__(self, chin, r=2):
        super().__init__()
        self.seq_path = nn.Sequential(
            nn.Conv1d(chin, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.ch_path = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),                    # GAP
            nn.Conv1d(chin, chin // r, 1, bias=False),  # Squeeze
            nn.ReLU(inplace=True),
            nn.Conv1d(chin // r, chin, 1, bias=False),  # Excitation
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        x: (B, C, L)
        - C: channels
        - L: length
        '''
        x_path = x * self.seq_path(x).expand_as(x)  # sequence attention
        x_ch = x * self.ch_path(x).expand_as(x)     # channel attention
        return x_path + x_ch
class ABSF(nn.Module):  # Attention-based Skip-fusion
    def __init__(self, chin):
        super().__init__()
        self.E_conv = nn.Conv1d(chin, chin, 1, bias=False)
        self.D_conv = nn.Conv1d(chin, chin, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.conv1 = nn.Conv1d(chin, chin, 1, bias=False)

    def forward(self, E_i, D_i1):
        '''
        E_i, D_i1: (B, C, L)
        '''
        B_i = self.sigmoid(self.E_conv(E_i) + self.D_conv(D_i1))
        A_i = self.sigmoid(self.conv1(B_i))
        return th.cat((E_i * A_i, D_i1), dim=1)  # concat or add??
class ROSE(nn.Module):
    """
    ROSE speech enhancement model.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        hidden (int): number of initial hidden channels.
        depth (int): number of layers.
        kernel_size (int): kernel size for each layer.
        stride (int): stride for each layer.
        max_H (int): maximum number of hidden channels.
    """
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 max_H=768):
        super().__init__()
        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.resample = 1

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.absf = nn.ModuleList()

        for i in range(depth):
            self.encoder.append(nn.Sequential(
                nn.Conv1d(chin, hidden, kernel_size, stride, bias=False),  # CNN block
                nn.ReLU(inplace=True),
                nn.Conv1d(hidden, hidden * 2, 1, bias=False),
                nn.GLU(dim=1),
                CSAtt(hidden)
            ))
            chin = hidden
            d_hidden = hidden * 2  # ABSF concatenation doubles the decoder input channels
            if i == 0:
                # output layer: no ReLU at the end
                self.decoder.append(nn.Sequential(
                    CSAtt(chin * 2),
                    nn.Conv1d(chin * 2, hidden * 2, 1, bias=False),
                    nn.GLU(dim=1),
                    nn.ConvTranspose1d(hidden, chout, kernel_size, stride, bias=False),
                ))
            else:
                self.decoder.insert(0, nn.Sequential(
                    CSAtt(d_hidden),
                    nn.Conv1d(d_hidden, d_hidden * 2, 1, bias=False),
                    nn.GLU(dim=1),
                    nn.ConvTranspose1d(d_hidden, chout, kernel_size, stride, bias=False),
                    nn.ReLU()
                ))
            chout = hidden
            self.absf.append(ABSF(hidden))
            hidden = min(hidden * 2, max_H)

        self.lstm = BLSTM(chin, bi=True)
        self.absf = self.absf[::-1]  # reverse so that absf[idx] pairs with decoder[idx]
    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in the convolutions, i.e. for every
        layer, (size of the input - kernel_size) % stride == 0.
        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)
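    # Added note: a worked example with the defaults (kernel_size=8, stride=4,
    # depth=5) gives valid_length(32000) == 32084, so forward() below right-pads
    # a 2 s / 16 kHz input by 84 samples before encoding.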
    def forward(self, mix):
        '''
        mix: (B, C, L)
        '''
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        # normalize the input; the inverse is applied to the output
        mean = mix.mean(dim=(1, 2), keepdim=True)
        std = mix.std(dim=(1, 2), keepdim=True)
        mix = (mix - mean) / (1e-5 + std)

        length = mix.shape[-1]
        x = mix
        x = F.pad(x, (0, self.valid_length(length) - length))

        skips_mask = []
        for encode in self.encoder:
            x = encode(x)
            skips_mask.append(x)

        x = x.permute(2, 0, 1)  # (B, C, L) -> (L, B, C) for the LSTM
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)  # (L, B, C) -> (B, C, L)

        x_mask = x
        for idx, decode in enumerate(self.decoder):
            skip = skips_mask.pop(-1)
            x_mask = self.absf[idx](skip, x_mask)
            x_mask = decode(x_mask)

        out = x_mask[..., :length]
        return std * out + mean
if __name__ == "__main__":
    x = th.rand([4, 32000])
    model = ROSE()
    out = model(x)
    print(out.size())
This is my network code, using the concat approach. Could you tell me whether it matches the paper?
@jeremy110 Hello, and thank you for your interest in our work. Our answers are as follows:
As stated in the paper, the ABSF module uses an element-wise add operation, but we also verified the concat operation experimentally, and the results are very close. The figure differing from the text was an oversight; thank you for pointing it out.
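For reference, here is a minimal sketch (not the authors' released code) of the element-wise-add variant described above, reconstructed from the ABSF class posted earlier in the thread; the name ABSFAdd and the 1x1 projection layers are assumptions that mirror the concat version:

import torch
from torch import nn

class ABSFAdd(nn.Module):
    # Sketch: ABSF fused with element-wise add instead of concat, per Eq. (2).
    def __init__(self, chin):
        super().__init__()
        self.E_conv = nn.Conv1d(chin, chin, 1, bias=False)
        self.D_conv = nn.Conv1d(chin, chin, 1, bias=False)
        self.conv1 = nn.Conv1d(chin, chin, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, E_i, D_i1):
        # attention mask computed exactly as in the concat version
        B_i = self.sigmoid(self.E_conv(E_i) + self.D_conv(D_i1))
        A_i = self.sigmoid(self.conv1(B_i))
        # adding keeps `chin` channels, so the decoder's CSAtt/Conv1d widths
        # would need to be halved relative to the concat version
        return E_i * A_i + D_i1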
We are glad you are interested in our work; your reproduction is quite close. As for the poor results you mentioned, I think there are a few other possible causes:
Hello, thank you for the quick reply.
1. Indeed, in my own implementation the two variants give about the same scores; thanks for the explanation.
2.1 For the loss function I tried MultiResolutionSTFTLoss ( https://github.com/JaeBinCHA7/DEMUCS-for-Speech-Enhancement/blob/main/utils/loss.py ), and I also replaced its SpectralConvergengeLoss with a MelSpectrogramLoss as shown below, so only the MFCC loss should be missing. With both of these, the trained model's PESQ lands around 2.7 and the STOI is poor at only 0.9; the enhanced audio is almost unlistenable. I plan to chain this model in front of whisper, so I left out the MFCC term.
2.2 Regarding the adjustment you mentioned to the CSAtt module's layer settings, does that mean leaving the CSAtt module out of some encoder or decoder layers?
2.3 I have tried various segment lengths. I originally used the 4 s from the paper, with the scores mentioned in 2.1; after lowering it to 2 s, PESQ lands around 2.9 and STOI reaches 0.92.
Thank you again for your reply.
loss function:
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Mel-spectrogram loss modules."""

# from distutils.version import LooseVersion

import librosa
import torch
import torch.nn.functional as F

is_pytorch_17plus = True  # LooseVersion(torch.__version__) >= LooseVersion("1.7")


class MelSpectrogram(torch.nn.Module):
    """Calculate Mel-spectrogram."""

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize MelSpectrogram module."""
        super().__init__()
        self.fft_size = fft_size
        if win_length is None:
            self.win_length = fft_size
        else:
            self.win_length = win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
        )
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
        self.stft_params = {
            "n_fft": self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center": self.center,
            "normalized": self.normalized,
            "onesided": self.onesided,
        }
        if is_pytorch_17plus:
            self.stft_params["return_complex"] = False

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")
    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).
        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        x_stft = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
        x_stft = x_stft.transpose(1, 2)
        x_power = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
        x_mel = torch.matmul(x_amp, self.melmat)
        x_mel = torch.clamp(x_mel, min=self.eps)

        return self.log(x_mel).transpose(1, 2)
class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        fs=16000,
        fft_size=400,
        hop_size=160,
        win_length=400,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize Mel-spectrogram loss."""
        super().__init__()
        self.mel_spectrogram = MelSpectrogram(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base,
        )

    def forward(self, y_hat, y):
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated signal tensor (B, 1, T).
            y (Tensor): Groundtruth signal tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        mel_hat = self.mel_spectrogram(y_hat)
        mel = self.mel_spectrogram(y)
        # normalize by the ground-truth norm, as in spectral convergence
        # (the snippet originally divided by the norm of mel_hat)
        mel_loss = torch.norm(mel_hat - mel, p="fro") / torch.norm(mel, p="fro")
        return mel_loss
if __name__ == "__main__":
    mel_loss_fn = MelSpectrogramLoss()
    x = torch.rand(4, 1, 32000)
    y = torch.rand(4, 1, 32000)
    print(mel_loss_fn(x, y))
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from mel_loss import MelSpectrogramLoss
def stft(x, fft_size, hop_size, win_length, window=None):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
real = x_stft.real
imag = x_stft.imag
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module (class name typo kept from the upstream code)."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=512, shift_size=100, win_length=400, window="hann_window"):
        """Initialize STFT loss module."""
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.register_buffer("window", getattr(torch, window)(win_length))
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return mag_loss
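# --- Added sketch, not part of the original file: a minimal multi-resolution
# wrapper like the MultiResolutionSTFTLoss mentioned in 2.1 above. The
# (fft_size, shift_size, win_length) triples below are assumptions based on
# commonly used defaults, not verified against the linked repository. ---
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Average of STFTLoss terms computed at several STFT resolutions."""

    def __init__(self, resolutions=((1024, 120, 600), (2048, 240, 1200), (512, 50, 240))):
        super().__init__()
        self.stft_losses = torch.nn.ModuleList(
            [STFTLoss(fft_size=fs, shift_size=ss, win_length=wl) for fs, ss, wl in resolutions]
        )

    def forward(self, x, y):
        # average the per-resolution magnitude losses
        return sum(loss(x, y) for loss in self.stft_losses) / len(self.stft_losses)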
class CustomLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn_1 = STFTLoss()
        self.loss_fn_2 = torch.nn.L1Loss()
        self.loss_fn_3 = MelSpectrogramLoss()

    def forward(self, preds, targets):
        mag_loss = 0.1 * self.loss_fn_1(preds, targets)
        mae_loss = self.loss_fn_2(preds, targets)
        mel_loss = 0.1 * self.loss_fn_3(preds, targets)
        loss = mel_loss + mag_loss + mae_loss
        return loss


if __name__ == "__main__":
    loss_fn = CustomLoss()
    x = torch.rand(4, 32000)
    y = torch.rand(4, 32000)
    print(loss_fn(x, y))
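As a side note on the MFCC term omitted in 2.1, here is a minimal sketch of one way to add it back, assuming an L1 distance between torchaudio MFCCs; the class name MFCCLoss and all parameter values are illustrative assumptions, not the paper's settings:

import torch
import torchaudio


class MFCCLoss(torch.nn.Module):
    """L1 distance between MFCCs (sketch; parameters are assumptions)."""

    def __init__(self, fs=16000, n_mfcc=13):
        super().__init__()
        self.mfcc = torchaudio.transforms.MFCC(
            sample_rate=fs,
            n_mfcc=n_mfcc,
            melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 40},
        )

    def forward(self, preds, targets):
        # (B, T) waveforms -> (B, n_mfcc, #frames) MFCCs
        return torch.nn.functional.l1_loss(self.mfcc(preds), self.mfcc(targets))

It could then be added as a fourth weighted term in CustomLoss above.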
@jeremy110 Hello, here are further answers to the questions you raised:
Hello, thank you for the quick reply.
Thank you again for your reply and explanation.
@jeremy110 Hello, here is our reply regarding segment length and offset:
Hello, thank you for your reply.
I will give that a try. Many thanks for the reply and explanation, and I hope your paper gets accepted.
Hello, I would like to ask whether you will release the model.
I am implementing this paper, training on the Voice Bank + DEMAND dataset, but my results are far off... Also, one part of the paper seems odd: equation (2) expresses element-wise adding, yet the figure ends with a concat. Which one should I follow?
If you are willing to reply, thank you very much.