RicherMans / Dasheng

Source for the Interspeech 2024 Paper "Scaling up masked audio encoder learning for general audio classification"
Apache License 2.0

About downstream task on audioset classification or clap #7

Closed haidog-yaqub closed 3 months ago

haidog-yaqub commented 3 months ago

Thank you for the amazing work!

Have you explored fine-tuning this model on downstream tasks such as AudioSet or CLAP? It is exciting to finally see a 1D audio foundation model, which will be way friendlier for downstream sequential tasks such as sound event detection and speech phoneme recognition! Would this method still be as robust as 2D methods on AudioSet?

RicherMans commented 3 months ago

Hey there @haidog-yaqub, thanks for your interest!

Have you explored fine-tuning this model on downstream tasks such as AudioSet or CLAP?

Sure, I did fine-tune it on AudioSet, using CED. Performance is alright, reaching an mAP of 49.7 for the base model. I made the checkpoint available if you would like to use it:

from typing import Any, Mapping
import dasheng
import torch

class DashengAudiosetClassifier(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.dashengmodel = dasheng.dasheng_base()
        self.classifier = torch.nn.Sequential(torch.nn.LayerNorm(self.dashengmodel.embed_dim), torch.nn.Linear(self.dashengmodel.embed_dim, 527))

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        # Load the encoder weights; strict=False because the checkpoint also contains classifier keys.
        self.dashengmodel.load_state_dict(state_dict, strict=False)
        # The classifier weights are stored under the 'outputlayer.' prefix in the checkpoint.
        for_classifier_dict = {}
        for k, v in state_dict.items():
            if 'outputlayer' in k:
                for_classifier_dict[k.replace('outputlayer.', '')] = v
        self.classifier.load_state_dict(for_classifier_dict)
        return self

    def forward(self, x):
        # Mean-pool the frame-level embeddings over time, then classify.
        x = self.dashengmodel(x).mean(1)
        return self.classifier(x).sigmoid()

mdl = DashengAudiosetClassifier()
check = torch.hub.load_state_dict_from_url('https://zenodo.org/records/13315686/files/dasheng_audioset_mAP497.pt?download=1', map_location='cpu')
mdl.load_state_dict(check)

mdl(torch.randn(1, 16000))
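
Running the snippet returns a tensor of shape (1, 527): sigmoid probabilities over the 527 AudioSet classes for the one-second dummy input.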

I also added the above code snippet to the README for future reference.

It is exciting to finally see a 1D audio foundation model, which will be way friendlier for downstream sequential tasks such as sound event detection and speech phoneme recognition!

Yeah, thanks for the comment. Obviously my baseline uses patches, but patches are not very "general". In HEAR, I would gain roughly +5 points for FSD50k across all dasheng models when using patches, but I lose performance on Maestro (musical notes, frame-level) and DCASE2016 (sound event onset). In most cases patches are too coarse in their time resolution, which would not work for tasks like phoneme recognition, as you said. Also, to the best of my knowledge, ASR does not work well with inputs coarser than 40 ms, which is why dasheng uses this "magic number".
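
To make the time-resolution point concrete, here is a minimal sketch of extracting frame-level features (assuming the dasheng_base encoder and 16 kHz input; the exact frame count may vary slightly with padding):

import torch
import dasheng

encoder = dasheng.dasheng_base()
wav = torch.randn(1, 16000 * 10)  # 10 s of audio at 16 kHz

with torch.no_grad():
    feats = encoder(wav)  # (batch, time_frames, embed_dim)

# With a ~40 ms hop, 10 s of audio yields roughly 250 frames,
# i.e. frame-level rather than patch-level time resolution.
print(feats.shape)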

Would this method still be as robust as 2D methods on AudioSet?

Yeah, I think so; at least the clip-level performance seems to outperform pretty much everything available except for CED-base. You can try it yourself! :D

haidog-yaqub commented 3 months ago

That's great! Thank you so much!