fschmid56 / EfficientAT

This repository aims to provide efficient CNNs for Audio Tagging. We provide AudioSet pre-trained models ready for downstream training and for extracting audio embeddings.
MIT License

Add feature maps #11

Closed: turian closed this issue 1 year ago

turian commented 1 year ago

In https://github.com/lucidrains/audiolm-pytorch/issues/177 I describe an approach for using EfficientAT as a discriminator for GAN training.

The only code change needed is this:


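    # Assumes the model module already has these module-level imports:
    #   from typing import List, Tuple, Union
    #   import torch.nn.functional as F
    #   from torch import Tensor

    def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
        fmaps = []

        for layer in self.features:
            x = layer(x)
            if return_fmaps:
                # keep every intermediate feature map for feature matching
                fmaps.append(x)

        # global-average-pooled embedding of the last feature map
        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
        x = self.classifier(x).squeeze()

        if features.dim() == 1 and x.dim() == 1:
            # squeeze() also removed the batch dimension (batch size 1): restore it
            features = features.unsqueeze(0)
            x = x.unsqueeze(0)

        if return_fmaps:
            # (classifier output, list of per-layer feature maps)
            return x, fmaps
        else:
            # (classifier output, pooled embedding): the original behavior
            return x, features

    def forward(self, x: Tensor, return_fmaps: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
        # pass the flag through; without it, callers of forward() could never get the feature maps
        return self._forward_impl(x, return_fmaps=return_fmaps)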

This lets the model optionally return its intermediate feature maps, so the generator can be trained with a feature-matching loss.
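As a rough illustration of how the returned feature maps could be consumed, here is a minimal sketch of an L1 feature-matching loss of the kind used in MelGAN and HiFi-GAN. It assumes both lists come from the same discriminator, layer for layer, and that real and generated audio were already converted into the input the model expects (e.g., log-mel spectrograms). The name `feature_matching_loss` is hypothetical and not part of EfficientAT.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor


def feature_matching_loss(real_fmaps: List[Tensor], fake_fmaps: List[Tensor]) -> Tensor:
    """Hypothetical helper: L1 feature-matching loss between discriminator activations.

    Each layer contributes the mean absolute difference between its real and
    generated feature maps; the per-layer values are then averaged.
    """
    if len(real_fmaps) != len(fake_fmaps):
        raise ValueError("real and fake feature map lists must have the same length")
    # detach the real activations so gradients only flow into the generator's branch
    per_layer = [F.l1_loss(fake, real.detach()) for real, fake in zip(real_fmaps, fake_fmaps)]
    return torch.stack(per_layer).mean()
```

In a generator step one would call `discriminator._forward_impl(real_mel, return_fmaps=True)` and the same on the generated input, then add the weighted result to the generator's adversarial loss (`real_mel`, `fake_mel`, and the weight are placeholders). Averaging over layers keeps the scale independent of network depth; summing instead, as some implementations do, only changes the effective weight.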