asteroid-team / asteroid

The PyTorch-based audio source separation toolkit for researchers
https://asteroid-team.github.io/
MIT License

Is it possible to use asteroid.engine.system.System in tandem with pl.LightningDataModule? #399

Closed: actuallyaswin closed this issue 3 years ago

actuallyaswin commented 3 years ago

I'm dealing with an unusual scenario: I have two different audio denoising models, and I want to train each of them with two specific procedures. I like how the System() class has a common_step function, so I want my models to be modular and usable with either System1 or System2, depending on which training procedure I'd like to follow.

But I'd like to avoid the way the System class couples the dataloaders to its instantiation. A DataModule is great because, depending on the init args I pass to it, I get a specific set of dataloaders for train/val/test.

In other words, I'd like my two flavors of dataloaders to be agnostic to the chosen model architecture, while the model gets wrapped in the System appropriate to the flavor I'm working with.

Does this make sense? The API Overview notebook doesn't mention DataModules, which is why I'm curious.

jonashaag commented 3 years ago

I'm having difficulty understanding what you want. Can you please be more specific about your use case and maybe sketch a potential solution, if you have one?

mpariente commented 3 years ago

I can give a better answer on Monday, but passing None as train_loader should do the trick.

actuallyaswin commented 3 years ago

Thanks to both of you for your interest. My specific experiment is complicated to explain, so let me paste some code snippets below. First, I have two types of neural networks: one that uses a GRU and one that uses ConvTasNet.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
from asteroid.models import ConvTasNet


class NetworkRNN(nn.Module):

    # STFT parameters
    fft_size: int = 1024
    hop_length: int = 256

    def __init__(self, hidden_size: int, num_layers: int = 2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        n_freqs = self.fft_size // 2 + 1

        # register the window as a non-trainable buffer so it follows the
        # module across devices (a class-level nn.Parameter is not registered)
        self.register_buffer("window", torch.hann_window(self.fft_size))

        # create a neural network which predicts a TF ratio mask in [0, 1]
        self.encoder = nn.GRU(
            input_size=n_freqs,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=n_freqs),
            nn.Sigmoid(),
        )

    def stft(self, waveform: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculates the Short-time Fourier transform (STFT)."""

        # complex spectrogram of shape (batch, n_freqs, n_frames)
        spectrogram = torch.stft(
            waveform, self.fft_size, self.hop_length,
            window=self.window, return_complex=True,
        )

        # swap to (batch, n_frames, n_freqs) for batch-first RNN processing
        spectrogram = spectrogram.transpose(1, 2)

        # calculate the magnitude spectrogram
        magnitude_spectrogram = spectrogram.abs()

        return spectrogram, magnitude_spectrogram

    def istft(self, spectrogram: torch.Tensor,
              mask: Optional[torch.Tensor] = None,
              length: Optional[int] = None) -> torch.Tensor:
        """Calculates the inverse Short-time Fourier transform (ISTFT)."""

        # apply a time-frequency mask if provided; done out of place so the
        # caller's spectrogram is reused unchanged for the residual
        if mask is not None:
            spectrogram = spectrogram * mask

        # swap back to (batch, n_freqs, n_frames)
        spectrogram = spectrogram.transpose(1, 2)

        # perform the inverse short-time Fourier transform
        return torch.istft(
            spectrogram, self.fft_size, self.hop_length,
            window=self.window, length=length,
        )

    # @auto_move_data is dropped: it only acts on LightningModule methods,
    # and the Lightning Trainer already moves batches to the model's device
    def forward(self, waveform: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        # convert waveform to spectrogram
        X, X_magnitude = self.stft(waveform)

        # generate a time-frequency mask of shape (batch, n_frames, n_freqs)
        H, _ = self.encoder(X_magnitude)
        Y = self.decoder(H)

        # convert masked spectrograms back to waveforms of the input length
        length = waveform.shape[-1]
        denoised = self.istft(X, mask=Y, length=length)
        residual = self.istft(X, mask=1 - Y, length=length)

        return denoised, residual


class NetworkCTN(nn.Module):

    def __init__(self):
        super().__init__()
        # two "sources": speech and residual noise
        self.network = ConvTasNet(n_src=2)

    def forward(self, waveform: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # ConvTasNet returns estimates of shape (batch, n_src, time)
        output = self.network(waveform)
        denoised = output[:, 0, :]
        residual = output[:, 1, :]
        return denoised, residual
```

Both networks perform speech enhancement, and their forward methods return a pair of waveforms: the denoised speech and the residual source (i.e. the noise).

I want to compare the speech-enhancement performance of these two networks depending on whether I train them with unimodal batches or bimodal (contrastive) batches. I could use asteroid.engine.system.System for this, but for now I'm just subclassing LightningModule manually.

```python
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from asteroid.losses import SingleSrcMSE, SingleSrcNegSDR, SingleSrcNegSTOI


class System(LightningModule):

    def __init__(
        self,
        network: nn.Module,
        loss_function: str = 'mse',
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        self.network = network
        self.learning_rate = learning_rate
        try:
            self.loss_fn = {
                'mse': SingleSrcMSE(),
                'sdr': SingleSrcNegSDR('sdsdr'),
                'stoi': SingleSrcNegSTOI(16000, True, True),
            }[loss_function]
        except KeyError:
            raise ValueError('Supported losses are: {"mse", "sdr", "stoi"}.') from None

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.network.parameters(),
            lr=self.learning_rate
        )

    def forward(self, x):
        return self.network(x)

    def training_step(self, batch, batch_idx):
        # unpack batch
        (mixture, target_speech, target_noise) = batch

        # denoise input
        estimate_speech, estimate_noise = self(mixture)

        # compute standard source separation loss; Asteroid's single-source
        # losses return one value per example, so average to get a scalar
        loss_speech = self.loss_fn(estimate_speech, target_speech).mean()
        loss_noise = self.loss_fn(estimate_noise, target_noise).mean()
        return loss_speech + loss_noise


class SystemContrastive(System):

    def training_step(self, batch, batch_idx):
        # unpack batch: two (mixture, speech, noise) triplets plus boolean labels
        (mixture_1, target_speech_1, target_noise_1,
         mixture_2, target_speech_2, target_noise_2, labels) = batch

        # denoise left input
        estimate_speech_1, estimate_noise_1 = self(mixture_1)

        # denoise right input
        estimate_speech_2, estimate_noise_2 = self(mixture_2)

        # compute contrastive source separation loss
        # (loss_contrastive is the custom loss being developed; not shown here)
        loss_speech = loss_contrastive(
            estimate_speech_1, target_speech_1,
            estimate_speech_2, target_speech_2,
            labels=labels,
            distance_fn=self.loss_fn,
        )
        loss_noise = loss_contrastive(
            estimate_noise_1, target_noise_1,
            estimate_noise_2, target_noise_2,
            labels=~labels,
            distance_fn=self.loss_fn,
        )
        return loss_speech + loss_noise
```

The two systems differ only in the type of batch they expect: the contrastive one expects a double-audio batch, i.e. two (mixture, speech, noise) triplets plus a label tensor, as its training_step shows.

So the experiment can be instantiated in several ways, for example:

```python
model = NetworkRNN(128, 2)
system = System(model, 'sdr', 1e-4)  # will train an RNN in a unimodal fashion
system = SystemContrastive(model, 'sdr', 1e-4)  # will train an RNN in a contrastive fashion

model = NetworkCTN()
system = System(model, 'sdr', 1e-4)  # will train ConvTasNet in a unimodal fashion
system = SystemContrastive(model, 'sdr', 1e-4)  # will train ConvTasNet in a contrastive fashion
```

If the system is a SystemContrastive, it has to be paired with a set of contrastive dataloaders; with the regular System, I need a set of regular dataloaders.
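One way to keep that pairing in a single place is a small lookup from training procedure to (System class, data flavor). In this sketch, `System`, `SystemContrastive`, and `DenoisingDataModule` are lightweight stand-ins for the real classes, and `PROCEDURES` and `build_experiment` are hypothetical names; in the real setup they would be the LightningModule subclasses above and a DataModule that accepts a `flavor` argument.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class System:  # stand-in for the unimodal LightningModule subclass
    network: Any
    loss_function: str = "mse"
    learning_rate: float = 1e-3


class SystemContrastive(System):  # stand-in for the contrastive subclass
    pass


@dataclass
class DenoisingDataModule:  # stand-in for a DataModule with a `flavor` arg
    flavor: str
    batch_size: int = 4


# one entry per training procedure: which wrapper to use, which batches it needs
PROCEDURES = {
    "unimodal": (System, "unimodal"),
    "contrastive": (SystemContrastive, "contrastive"),
}


def build_experiment(network, procedure, loss_function="sdr",
                     learning_rate=1e-4, batch_size=4):
    """Return a (system, datamodule) pair whose batch layouts match."""
    try:
        system_cls, flavor = PROCEDURES[procedure]
    except KeyError:
        raise ValueError(
            f"Unknown procedure {procedure!r}; expected one of {sorted(PROCEDURES)}"
        ) from None
    system = system_cls(network, loss_function, learning_rate)
    datamodule = DenoisingDataModule(flavor=flavor, batch_size=batch_size)
    return system, datamodule
```

With the real classes, the returned pair would go straight into `trainer.fit(system, datamodule=datamodule)`, so a SystemContrastive cannot be handed unimodal batches by accident, and the network choice stays independent of the data flavor.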

My ultimate goal is to compare a standard source separation loss against a custom contrastive loss that I'm designing. I'm not sure whether this is the right way to organize my experiment, or whether there's a benefit to using asteroid.engine.system.System here instead of my generic System class.

mpariente commented 3 years ago

Thanks for the snippets. I think your subclasses are completely fine. It's not much code, so I wouldn't spend much time trying to reuse Asteroid's System, which is not meant to be general, just to work well in classic use cases.

I'm curious about your results; do let us know how it works!

mpariente commented 3 years ago

Have you succeeded in the end?

actuallyaswin commented 3 years ago

Closing this issue: I ended up using neither PyTorch Lightning nor the System wrapper, which, as @mpariente said, wasn't quite applicable to my use case. 😌