lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Use a pretrained model as a discriminator (and for feature maps) #177

Closed · turian closed 1 year ago

turian commented 1 year ago

Just a speculative idea that I've been playing around with internally:

EfficientAT, particularly the "mn40_as_ext" model, is a very high-performing pretrained audio embedding model. It's an EfficientNet CNN vision model distilled from PaSST, which had the highest-performing scores on FSD50K (general audio) class prediction in the HEAR Benchmark. This group's code is also very easy to use, and they are very responsive on GitHub.

The bleeding-edge idea I propose as an option is to use EfficientAT (with the default model "mn40_as_ext") as a discriminator, i.e. with only a single class prediction. The only real change needed is this:


    # module-level imports this patch relies on (already present in EfficientAT's model file)
    from typing import List, Tuple, Union

    import torch.nn.functional as F
    from torch import Tensor

    def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
        # collect the intermediate activations of every feature block
        fmaps = []

        for layer in self.features:
            x = layer(x)
            if return_fmaps:
                fmaps.append(x)

        # pooled penultimate features and final class logits
        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
        x = self.classifier(x).squeeze()

        if features.dim() == 1 and x.dim() == 1:
            # .squeeze() also dropped the batch dimension for batch size 1; restore it
            features = features.unsqueeze(0)
            x = x.unsqueeze(0)

        if return_fmaps:
            return x, fmaps
        else:
            return x, features

    def forward(self, x: Tensor, return_fmaps: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
        # thread return_fmaps through; the original forward never passed it on
        return self._forward_impl(x, return_fmaps=return_fmaps)

With that, you can also retrieve the feature maps for learning feature matching in the generator.
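For context, the feature-matching term could be as simple as an L1 distance between the discriminator's feature maps on real vs. generated audio. A minimal sketch (the function name and usage are illustrative, not from this repo):

    import torch
    import torch.nn.functional as F

    def feature_matching_loss(real_fmaps, fake_fmaps):
        # L1 distance between the discriminator's intermediate activations
        # on real vs. generated audio, averaged over layers
        return torch.stack([
            F.l1_loss(fake, real.detach())
            for real, fake in zip(real_fmaps, fake_fmaps)
        ]).mean()

    # e.g., with the patched model above:
    # _, real_fmaps = model(real_audio, return_fmaps=True)
    # _, fake_fmaps = model(fake_audio, return_fmaps=True)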

In fact, my biggest reservation about this approach is that the discriminator loss drops so. damn. fast. because it's pretrained, making it hard for the generator to catch up unless you use a very, very slow learning rate, TTUR, or similar.
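Concretely, TTUR here would just mean giving the two networks separate learning rates. A minimal sketch, where generator and discriminator are hypothetical nn.Module instances and the exact rates are illustrative:

    import torch

    gen_opt = torch.optim.Adam(generator.parameters(), lr = 1e-4, betas = (0.5, 0.9))
    # update the already-strong pretrained discriminator far more slowly
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr = 1e-5, betas = (0.5, 0.9))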

lucidrains commented 1 year ago

@turian yea sure, i can loosen the restrictions and allow the discriminator to be passed in

lucidrains commented 1 year ago

@turian ok, you can pass it into the Soundstream as stft_discriminator, once you get the authors of that network to return feature maps
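for reference, usage would then look roughly like this (codebook_size and rq_num_quantizers are from the readme; EfficientATDiscriminator is a stand-in name for your wrapper):

    from audiolm_pytorch import SoundStream

    soundstream = SoundStream(
        codebook_size = 1024,
        rq_num_quantizers = 8,
        stft_discriminator = EfficientATDiscriminator(...)  # must return (logits, fmaps)
    )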

turian commented 1 year ago

@lucidrains cool but shouldn't that be called efficient_at_discriminator?

turian commented 1 year ago

Sure. If you like, I can share the wrapper that I wrote.

VERY IMPORTANT:


        # This might be slow to backprop through
        if sample_rate == 32000:
            self.resampler = torch.nn.Identity()
        else:
            self.resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=32000
            )

Since EfficientAT is trained on 32 kHz (32000 Hz) audio. torchaudio's Resample passes gradients back, btw.
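
For completeness, here's roughly the shape of the wrapper I have in mind (the class name is made up, and model is assumed to be an EfficientAT network patched as above to return feature maps; note their models also expect mel-spectrogram input, which I'm glossing over here):

    import torchaudio
    from torch import Tensor, nn

    class EfficientATDiscriminator(nn.Module):
        # hypothetical wrapper: `model` is a pretrained EfficientAT network,
        # patched as above so forward(x, return_fmaps=True) -> (logits, fmaps)
        def __init__(self, model: nn.Module, sample_rate: int):
            super().__init__()
            self.model = model

            # This might be slow to backprop through
            if sample_rate == 32000:
                self.resampler = nn.Identity()
            else:
                self.resampler = torchaudio.transforms.Resample(
                    orig_freq = sample_rate, new_freq = 32000
                )

        def forward(self, x: Tensor):
            # resample to the 32 kHz the pretrained model expects, then
            # return (logits, feature maps) for the adversarial losses
            return self.model(self.resampler(x), return_fmaps = True)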

lucidrains commented 1 year ago

> @lucidrains cool but shouldn't that be called efficient_at_discriminator?

the idea is you can instantiate with any discriminator and just pass it in