Open SEMLLYCAT opened 1 year ago
Thanks for your attention.
Please refer to the implementation below. If you run this model, you should find that its performance is similar to that of the $m=8$ configuration.
import torch
import torch.nn as nn
import torchaudio as audio
from torch.nn import functional
from torchinfo import summary
from audio_zen.model.base_model import BaseModel
from audio_zen.model.module.sequence_model import SequenceModel
class Model(BaseModel):
    """Recurrent model that maps a noisy magnitude spectrogram to a two-channel output.

    The input magnitude is compressed by a mel filterbank, normalized, passed
    through a two-stage recurrent encoder and then through a two-stage
    recurrent decoder that emits ``2 * encoder_input_size`` features per frame.
    """

    def __init__(
        self,
        look_ahead,
        shrink_size,
        sequence_model,
        encoder_input_size,
        num_mels,
        noisy_input_num_neighbors,
        encoder_output_num_neighbors,
        norm_type="offline_laplace_norm",
        weight_init=False,
    ):
        """
        Simply FullSubNet.

        Args:
            look_ahead: number of frames padded at the end of the time axis in
                ``forward`` and dropped from the start of the output.
            shrink_size: stored on the instance; not used by this class.
            sequence_model: recurrent cell type, either "GRU" or "LSTM".
            encoder_input_size: number of frequency bins of the input
                magnitude (``n_stft`` of the mel filterbank). Also sets the
                decoder output width to ``2 * encoder_input_size``.
            num_mels: number of mel bins; also the input width of the encoder.
            noisy_input_num_neighbors: stored on the instance; not used by this class.
            encoder_output_num_neighbors: stored on the instance; not used by this class.
            norm_type: name handed to ``self.norm_wrapper`` (defined outside
                this class) to select the normalization callable.
            weight_init: when true, ``self.weight_init`` is applied to every
                sub-module via ``self.apply``.

        Notes:
            NOTE(review): the original note described an encoder and a
            bottleneck as the fullband and subband models. Only an encoder and
            a decoder are built here; there is no separate sub-band module.
        """
        super().__init__()
        assert sequence_model in (
            "GRU",
            "LSTM",
        ), f"{self.__class__.__name__} only support GRU and LSTM."

        # Width of the representation handed from the encoder to the decoder.
        latent_size = 64

        # Encoder
        # The first layer consumes the mel features, so its input width must
        # follow ``num_mels`` instead of being fixed to 64.
        self.encoder = nn.Sequential(
            SequenceModel(
                input_size=num_mels,
                hidden_size=384,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=384,
                hidden_size=257,
                output_size=latent_size,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function="ReLU",
            ),
        )

        # Mel filterbank
        self.mel_scale = audio.transforms.MelScale(
            n_mels=num_mels,
            sample_rate=16000,
            f_min=0,
            f_max=8000,
            n_stft=encoder_input_size,
        )

        # Decoder
        # ``forward`` reshapes the decoder output to [B, 2, F, T] where F is
        # the input frequency count, so the output width must be
        # ``2 * encoder_input_size`` rather than a fixed ``257 * 2``.
        self.decoder_lstm = nn.Sequential(
            SequenceModel(
                input_size=latent_size,
                hidden_size=512,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=512,
                hidden_size=512,
                output_size=encoder_input_size * 2,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
        )

        self.look_ahead = look_ahead
        self.norm = self.norm_wrapper(norm_type)
        self.num_mels = num_mels
        self.noisy_input_num_neighbors = noisy_input_num_neighbors
        self.enc_output_num_neighbors = encoder_output_num_neighbors
        self.shrink_size = shrink_size

        if weight_init:
            self.apply(self.weight_init)

    # fmt: off
    def forward(self, mix_mag):
        """
        Args:
            mix_mag: noisy magnitude spectrogram

        Returns:
            The real part and imag part of the enhanced spectrogram

        Shapes:
            mix_mag: [B, 1, F, T]
            return: [B, 2, F, T]
        """
        assert mix_mag.dim() == 4
        # Pad `look_ahead` frames at the end of the time axis; the same number
        # of frames is dropped from the start of the output below.
        mix_mag = functional.pad(mix_mag, [0, self.look_ahead])
        batch_size, num_channels, num_freqs, num_frames = mix_mag.size()
        assert num_channels == 1, f"{self.__class__.__name__} takes a mag feature as inputs."

        # Mel filtering
        mix_mel_mag = self.mel_scale(mix_mag)  # [B, 1, F_mel, T]

        # Encoder: fold the singleton channel into the feature axis first.
        enc_input = self.norm(mix_mel_mag).reshape(batch_size, -1, num_frames)  # [B, F_mel, T]
        enc_output = self.encoder(enc_input)

        # Decoder
        # The original reshaped to [B, C, -1, T] and straight back; with C == 1
        # that round trip is a no-op, so a single reshape is kept.
        dec_input = enc_output.reshape(batch_size, -1, num_frames)
        decoder_lstm_output = self.decoder_lstm(dec_input)  # reshaped below as [B, 2 * F, T]
        dec_output = decoder_lstm_output.reshape(batch_size, 2, num_freqs, num_frames)

        # Output: drop the first `look_ahead` frames so the length matches the unpadded input.
        return dec_output[:, :, :, self.look_ahead:]
if __name__ == "__main__":
    # Smoke test: build the model, run one random batch through it without
    # tracking gradients, and print the torchinfo summary.
    input_shape = (1, 1, 257, 63)
    model_kwargs = dict(
        look_ahead=2,
        shrink_size=16,
        sequence_model="LSTM",
        encoder_input_size=257,
        num_mels=64,
        noisy_input_num_neighbors=5,
        encoder_output_num_neighbors=0,
    )
    with torch.no_grad():
        # The random input is drawn before the model is constructed, as in the original.
        noisy_mag = torch.rand(*input_shape)
        model = Model(**model_kwargs)
        output = model(noisy_mag)
        print(summary(model, input_shape, device="cpu"))
Thank you very much for your excellent work. I am replicating the results from your paper (Fast FullSubNet). I have reproduced the results for $m=2$ and $m=8$, but when I tried to reproduce the $m=\infty$ case, I could not obtain the result described in your paper so far. If it is all right with you, could you provide the model structure for $m=\infty$, i.e. the one in which the sub-model is removed? Thank you very much for your help; I am looking forward to your reply!