YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Regression Task #110

Open · BaMarcy opened this issue 1 year ago

BaMarcy commented 1 year ago

Hello,

I'd like to know whether this model can be used for regression tasks. From my reading of the architecture, the model ends in an nn.Linear() layer, which leads me to believe that regression might be supported. However, I'd appreciate some clarification to make sure my understanding is correct.

Thanks!

YuanGongND commented 1 year ago

hi there,

Yes, AST can do regression. You don't even need to change the model architecture: just set label_dim=1 when you instantiate the model, and the output will be a single value.
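To show why label_dim=1 is enough, here is a minimal sketch of a classification-style head whose final nn.Linear is sized by label_dim. It assumes the AST head is a LayerNorm followed by nn.Linear(embed_dim, label_dim) applied to a pooled 768-dim embedding (the base model size); the RegressionHead class is a hypothetical stand-in, not the repo's code.

```python
import torch
import torch.nn as nn


class RegressionHead(nn.Module):
    """Hypothetical stand-in for AST's final head (assumed: LayerNorm + Linear).
    The output width equals label_dim, so label_dim=1 yields one value per clip."""

    def __init__(self, embed_dim: int = 768, label_dim: int = 1):
        super().__init__()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, label_dim),
        )

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: [batch, embed_dim] summary embedding from the transformer
        return self.mlp_head(pooled)


head = RegressionHead(embed_dim=768, label_dim=1)
pooled = torch.randn(4, 768)  # fake pooled features for a batch of 4 clips
out = head(pooled)
print(out.shape)  # torch.Size([4, 1])
```

With the actual repo, the equivalent step should be passing label_dim=1 to ASTModel (see the linked line), after which an input batch should produce an output of shape [batch, 1]. Note that the output is a raw linear value; no sigmoid or softmax should be applied for regression.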

https://github.com/YuanGongND/ast/blob/31088be8a3f6ef96416145c4b8d43c81f99eba7a/src/models/ast_models.py#L47

However, you do need to change the training pipeline to use a regression loss, such as MSE.
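A minimal sketch of such a training step, assuming your pipeline currently uses a classification loss and that the model outputs [batch, 1]. The tiny nn.Sequential model and the random data below are hypothetical placeholders for an AST model built with label_dim=1 and a real dataset; only the loss swap and the shape handling are the point.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for an AST model built with label_dim=1:
# maps a 768-dim pooled feature to one raw (un-activated) prediction.
model = nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Regression loss in place of a classification loss (e.g. BCE or cross-entropy).
loss_fn = nn.MSELoss()

features = torch.randn(8, 768)  # fake batch of 8 pooled embeddings
targets = torch.rand(8)         # fake continuous targets, shape [8]


def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    model.train()
    optimizer.zero_grad()
    # The model returns [batch, 1]; squeeze the last dim so predictions match
    # targets of shape [batch]. Without this, MSELoss would broadcast [8, 1]
    # against [8] into an [8, 8] comparison (PyTorch emits a warning).
    preds = model(x).squeeze(-1)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()
    return loss.item()


losses = [train_step(features, targets) for _ in range(200)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

The shape fix matters more than it looks: a mismatched [batch, 1] vs. [batch] pair still runs, but silently computes the wrong loss. If your evaluation code applies a sigmoid to the outputs (as is common for multi-label classification), that step should also be removed for regression.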

I cannot guarantee anything beyond "the provided code can reproduce the results shown in the paper." But in my opinion, AST should achieve regression results at least comparable to those of other models.

-Yuan