Audio-WestlakeU / FullSubNet

PyTorch implementation of "FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement."
https://fullsubnet.readthedocs.io/en/latest/
MIT License
532 stars 152 forks source link

Fast FullSubNet m=infinity #62

Open SEMLLYCAT opened 1 year ago

SEMLLYCAT commented 1 year ago

Thank you very much for your excellent work. I'm replicating the results from your paper (Fast FullSubNet). I have reproduced the results for $m = 2$ and $m = 8$, but for $m = \infty$ I have not yet been able to obtain the results described in the paper. If it's all right with you, could you share the model structure for $m = \infty$, i.e., the model with the sub-band model removed? Thank you very much for your help, and I look forward to your reply!

haoxiangsnr commented 1 year ago

Thanks for your attention.

Please refer to the implementation below. If you run this model, you may find that its performance is similar to that of $m = 8$.

import torch
import torch.nn as nn
import torchaudio as audio
from torch.nn import functional
from torchinfo import summary

from audio_zen.model.base_model import BaseModel
from audio_zen.model.module.sequence_model import SequenceModel

class Model(BaseModel):
    def __init__(
        self,
        look_ahead,
        shrink_size,
        sequence_model,
        encoder_input_size,
        num_mels,
        noisy_input_num_neighbors,
        encoder_output_num_neighbors,
        norm_type="offline_laplace_norm",
        weight_init=False,
    ):
        """
        Fast FullSubNet variant with the sub-band model removed (m = infinity).

        The noisy magnitude is projected onto a mel filterbank, normalized, passed
        through a two-layer full-band sequence model (the encoder), and then mapped
        by a second two-layer sequence model (the decoder) directly to the real and
        imaginary parts of the enhanced linear-frequency spectrogram.

        Args:
            look_ahead: number of future frames the model may see. The input is
                right-padded by this many frames and the first ``look_ahead``
                output frames are discarded.
            shrink_size: stored but not used by ``forward`` in this variant
                (presumably kept for config compatibility with the other Fast
                FullSubNet variants).
            sequence_model: "GRU" or "LSTM".
            encoder_input_size: number of linear frequency bins F of the input.
                It is the ``n_stft`` of the mel filterbank and determines the
                decoder output size (2 * F).
            num_mels: number of mel bands; this is the encoder's input size.
            noisy_input_num_neighbors: stored but not used by ``forward`` here.
            encoder_output_num_neighbors: stored but not used by ``forward`` here.
            norm_type: normalization selected through ``self.norm_wrapper``.
            weight_init: if True, apply ``self.weight_init`` to every submodule.
        """
        super().__init__()
        assert sequence_model in (
            "GRU",
            "LSTM",
        ), f"{self.__class__.__name__} only support GRU and LSTM."

        # Width of the features passed from the encoder to the decoder.
        bottleneck_size = 64

        # Encoder (full-band model) on mel features: [B, num_mels, T] -> [B, 64, T].
        # Its input size must match num_mels, since it consumes the mel filterbank output.
        self.encoder = nn.Sequential(
            SequenceModel(
                input_size=num_mels,
                hidden_size=384,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=384,
                hidden_size=257,
                output_size=bottleneck_size,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function="ReLU",
            ),
        )

        # Mel filterbank: [..., F, T] -> [..., num_mels, T]
        self.mel_scale = audio.transforms.MelScale(
            n_mels=num_mels,
            sample_rate=16000,
            f_min=0,
            f_max=8000,
            n_stft=encoder_input_size,
        )

        # Decoder: [B, 64, T] -> [B, 2 * F, T]. The real part comes first, then the imaginary part.
        # Its output size must be 2 * encoder_input_size so forward() can reshape it to [B, 2, F, T].
        self.decoder_lstm = nn.Sequential(
            SequenceModel(
                input_size=bottleneck_size,
                hidden_size=512,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=512,
                hidden_size=512,
                output_size=encoder_input_size * 2,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
        )

        self.look_ahead = look_ahead
        self.norm = self.norm_wrapper(norm_type)
        self.num_mels = num_mels
        # The three attributes below are not read by forward(); they are kept so the attribute set is unchanged.
        self.noisy_input_num_neighbors = noisy_input_num_neighbors
        self.enc_output_num_neighbors = encoder_output_num_neighbors
        self.shrink_size = shrink_size

        if weight_init:
            self.apply(self.weight_init)

    # fmt: off
    def forward(self, mix_mag):
        """
        Enhance a noisy magnitude spectrogram.

        Args:
            mix_mag: noisy magnitude spectrogram of shape [B, 1, F, T].
                F must equal ``encoder_input_size``.

        Returns:
            The real and imaginary parts of the enhanced spectrogram, stacked on
            dim 1, with shape [B, 2, F, T].
        """
        assert mix_mag.dim() == 4
        # Append look_ahead zero frames at the end of the time axis.
        mix_mag = functional.pad(mix_mag, [0, self.look_ahead])
        batch_size, num_channels, num_freqs, num_frames = mix_mag.size()
        assert num_channels == 1, f"{self.__class__.__name__} takes a mag feature as inputs."

        # Mel filtering: [B, 1, F, T] -> [B, 1, F_mel, T]
        mix_mel_mag = self.mel_scale(mix_mag)

        # Encoder (full-band model): [B, F_mel, T] -> [B, 64, T]
        enc_input = self.norm(mix_mel_mag).reshape(batch_size, -1, num_frames)
        enc_output = self.encoder(enc_input)

        # Decoder: [B, 64, T] -> [B, 2 * F, T] -> [B, 2, F, T] (real part first, then imaginary part)
        dec_output = self.decoder_lstm(enc_output).reshape(batch_size, 2, num_freqs, num_frames)

        # Drop the first look_ahead frames so the output has the original T frames.
        # Because the sequence models are unidirectional, output frame t has seen input frames up to t + look_ahead.
        return dec_output[:, :, :, self.look_ahead:]

if __name__ == "__main__":
    # Smoke test: run one forward pass on random input, check the output shape, and print a layer summary.
    with torch.no_grad():
        num_freqs, num_frames = 257, 63
        noisy_mag = torch.rand(1, 1, num_freqs, num_frames)
        model = Model(
            look_ahead=2,
            shrink_size=16,
            sequence_model="LSTM",
            encoder_input_size=num_freqs,
            num_mels=64,
            noisy_input_num_neighbors=5,
            encoder_output_num_neighbors=0,
        )
        output = model(noisy_mag)
        # The model returns stacked real/imag parts with the input's frequency and time sizes.
        assert output.shape == (1, 2, num_freqs, num_frames), output.shape
        print(summary(model, (1, 1, num_freqs, num_frames), device="cpu"))