YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Regression Task #110

Open BaMarcy opened 10 months ago

BaMarcy commented 10 months ago

Hello,

I'm interested in knowing whether this model can be utilized for regression tasks. From my analysis of the architecture, it appears that the model incorporates nn.Linear() towards the final layer, which leads me to believe that regression tasks might be supported. However, I'd greatly appreciate some clarification to ensure my understanding is accurate.

Thanks!

YuanGongND commented 10 months ago

Hi there,

Yes, AST can do regression. You don't even need to change the model architecture: just set label_dim=1 when you instantiate the model, so that the output is a single value.

https://github.com/YuanGongND/ast/blob/31088be8a3f6ef96416145c4b8d43c81f99eba7a/src/models/ast_models.py#L47

You do, however, need to change the training pipeline to use a regression loss such as MSE in place of the classification loss.
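
A minimal sketch of that change, assuming a PyTorch training loop. TinyRegressor is a hypothetical stand-in so the snippet runs without the repo; in practice you would build ASTModel with label_dim=1 instead. The input shape (batch, time frames, frequency bins), the tensor sizes, the optimizer, and the learning rate are illustrative assumptions, not values taken from the repo's recipes.

```python
import torch
import torch.nn as nn


class TinyRegressor(nn.Module):
    """Hypothetical stand-in for ASTModel: (batch, time, freq) -> (batch, label_dim)."""

    def __init__(self, label_dim=1, input_tdim=10, input_fdim=16):
        super().__init__()
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_tdim * input_fdim, label_dim),
        )

    def forward(self, x):
        return self.head(x)


torch.manual_seed(0)

# With the repo available, swap this for ASTModel(label_dim=1, ...).
model = TinyRegressor(label_dim=1)

# Regression loss in place of the classification loss.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

# Random placeholder data: 8 "spectrograms" with one continuous target each.
spectrograms = torch.randn(8, 10, 16)
targets = torch.randn(8)

losses = []
for _ in range(5):
    optimizer.zero_grad()
    preds = model(spectrograms).squeeze(-1)  # (batch, 1) -> (batch,)
    loss = criterion(preds, targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The squeeze matters: nn.MSELoss broadcasts a (batch, 1) prediction against a (batch,) target into a (batch, batch) matrix, which gives the wrong loss (PyTorch only emits a warning). Keeping both tensors the same shape avoids that.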

I cannot guarantee anything beyond "the provided code can reproduce the results shown in the paper". But in my opinion, AST should achieve regression results at least comparable to those of other models.

-Yuan