YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Huggingface-compatible ImageNet pre-trained weights #109

Closed penguinwang96825 closed 1 year ago

penguinwang96825 commented 1 year ago

Absolutely brilliant work! I was wondering whether you could host Hugging Face-compatible ImageNet pre-trained weights of the AST model on the Hugging Face Hub. Rather than only the AudioSet fine-tuned version, this would give the community a good starting point for this project.

YuanGongND commented 1 year ago

hi there,

I am not sure if you mean this? https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer

It is implemented by the HF staff, not me, but I believe they host pretrained AST checkpoints (though not SSAST checkpoints).
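For reference, I believe their checkpoint can be loaded with the standard Auto classes, something like the sketch below (the checkpoint name here is an assumption on my side, and it is the AudioSet fine-tuned model rather than ImageNet-only weights):

from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Assumption: the AudioSet fine-tuned AST checkpoint published on the HF Hub.
ckpt = 'MIT/ast-finetuned-audioset-10-10-0.4593'
feature_extractor = AutoFeatureExtractor.from_pretrained(ckpt)
model = AutoModelForAudioClassification.from_pretrained(ckpt)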

-Yuan

penguinwang96825 commented 1 year ago

Hi Yuan,

Thank you for your prompt response. I appreciate it.

I would like to clarify a point from your paper: you mention that the AST model is initialised with ImageNet weights for cross-modality transfer. Would it be possible for you to upload these weights to the Hugging Face Model Hub? Doing so would enable us to utilise the convenient from_pretrained method provided by the transformers library.

Having these weights readily available on the HF Hub would greatly facilitate replication of the results in your paper: researchers and practitioners could load the weights directly and reproduce the same outcomes, making comparison and validation easier.

Without compatibility between the code in your repository and model weights hosted on the HF Hub, others may face challenges when attempting to reproduce your work.

Thank you for considering my request. Your support in this matter would undoubtedly enhance the accessibility and reproducibility of your research.

Looking forward to your feedback.

Warm regards, Yang

YuanGongND commented 1 year ago

hi there,

We did not train the ImageNet weights ourselves; instead, we use the popular timm package. You can dump the weights from timm (specifically, we use timm==0.4.5) and use the following code to initialize the model:

https://github.com/YuanGongND/ast/blob/31088be8a3f6ef96416145c4b8d43c81f99eba7a/src/models/ast_models.py#L59-L69

Since timm is a very commonly used package, you may be able to find an HF checkpoint. Unfortunately, we do not have time to do the conversion.
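In case it helps, here is a minimal sketch of dumping the weights on the timm side (assuming timm==0.4.5, where the model we use is registered as vit_deit_base_distilled_patch16_384; newer timm versions register it under a different name):

import timm
import torch

# Assumption: timm==0.4.5; later timm versions rename this model.
deit = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=True)

# Save the ImageNet pre-trained DeiT weights so they can be converted or re-used elsewhere.
torch.save(deit.state_dict(), 'deit_base_distilled_patch16_384.pth')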

-Yuan

penguinwang96825 commented 1 year ago

For future reference, I have uploaded Hugging Face-compatible ImageNet pre-trained weights of the AST model (based on vit_deit_base_distilled_patch16_384) to the Hub:

from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification
)

# Load the feature extractor and the ImageNet-initialised AST weights from the Hub.
# trust_remote_code=True allows the custom modelling code shipped with the checkpoint to be loaded.
model_id = 'yangwang825/ast-imagenet-10-10'
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id,
    trust_remote_code=True
)
model = AutoModelForAudioClassification.from_pretrained(
    model_id,
    trust_remote_code=True
)
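A minimal usage sketch continuing from the snippet above (the waveform here is dummy data; I assume the checkpoint exposes the standard AST feature-extractor interface expecting 16 kHz audio):

import torch

# Dummy 1-second, 16 kHz waveform; replace with real audio resampled to 16 kHz.
waveform = torch.randn(16000)
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors='pt')

with torch.no_grad():
    logits = model(**inputs).logits  # unnormalised class scores
print(logits.shape)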

YuanGongND commented 1 year ago

@penguinwang96825

Thanks so much for your contribution!