RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License

Different length audio input for infer mode #7

CaptainPrice12 closed this issue 2 years ago

CaptainPrice12 commented 2 years ago

Hi, thanks for the interesting work!

I have a question about the infer mode in htsat.py. During training, the audio input is always 10 seconds long. At inference, the model needs to handle variable-length audio that could be longer or shorter than 10 seconds, but the infer mode in htsat.py does this:

    if infer_mode:
        # in infer mode, we need to handle audio input of different lengths
        frame_num = x.shape[2]
        target_T = int(self.spec_size * self.freq_ratio)
        repeat_ratio = math.floor(target_T / frame_num)
        x = x.repeat(repeats=(1, 1, repeat_ratio, 1))
        x = self.reshape_wav2img(x)

What if the input length frame_num is greater than target_T (256 x 4 = 1024 here)? In that case repeat_ratio will be 0. So how does the model process audio longer than 10 seconds at inference?
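To make the concern concrete, here is a quick check of the excerpt's arithmetic in plain Python, assuming spec_size = 256 and freq_ratio = 4 as stated above. The helper `repeated_length` is hypothetical; it only mirrors the two lines that compute `repeat_ratio` and the resulting number of frames after repeating along the time axis.

```python
import math


def repeated_length(frame_num: int, spec_size: int = 256, freq_ratio: int = 4) -> tuple[int, int]:
    """Mirror the infer-mode arithmetic: return (repeat_ratio, frames after the repeat)."""
    target_T = int(spec_size * freq_ratio)           # 256 * 4 = 1024 frames
    repeat_ratio = math.floor(target_T / frame_num)  # same formula as the excerpt
    return repeat_ratio, frame_num * repeat_ratio    # repeating along time multiplies the length


for n in (400, 1024, 1500):
    print(n, repeated_length(n))
# 400 (2, 800)
# 1024 (1, 1024)
# 1500 (0, 0)
```

A 1500-frame input gives repeat_ratio = 0, and repeating a tensor zero times along an axis leaves that axis empty. Note also that shorter inputs are tiled by a whole number of repeats, so they do not necessarily reach exactly target_T (400 frames become 800).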

RetroCirce commented 2 years ago

Hi,

Thank you for your question. It is possible to handle an audio sample longer than target_T frames. Our method is to cut the sample into slices, each of length target_T. For example, a 15-second clip can be sliced into [0-10] and [5-15] seconds, and of course you can control the overlap. We then send these slices into HTS-AT and combine their outputs, like "voting".

We did not implement this in this repo because the audio samples in our experiments are always shorter than 10 seconds. But we ran into this question when applying the model in another scenario, and I implemented it here: https://github.com/LAION-AI/CLAP/blob/d2d5dae8ea8f1ee02ac40242418a36d1d567943a/src/open_clip/model.py#L484-L533

You can check it and adapt it to your code.

Best, Ke
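Ke's slice-and-vote approach can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the CLAP code linked above: `slice_frames` and `vote` are hypothetical helpers, the spectrogram is assumed to be laid out as (frames, mel_bins), the hop of 512 frames (50% overlap) is an arbitrary choice, and averaging or taking the max over slices are two common aggregation choices assumed here, since the thread does not specify how the "voting" is done.

```python
import numpy as np


def slice_frames(mel: np.ndarray, target_T: int = 1024, hop: int = 512) -> np.ndarray:
    """Cut a (frames, mel_bins) spectrogram into windows of target_T frames.

    hop < target_T makes consecutive windows overlap; the last window is aligned
    to the end of the input so the tail is always covered. Inputs that already
    fit in one window are returned as a single slice (no padding is done here).
    """
    n = mel.shape[0]
    if n <= target_T:
        return mel[np.newaxis]
    starts = list(range(0, n - target_T + 1, hop))
    if starts[-1] != n - target_T:
        starts.append(n - target_T)  # extra window flush with the end of the clip
    return np.stack([mel[s:s + target_T] for s in starts])


def vote(slice_probs: np.ndarray, mode: str = "mean") -> np.ndarray:
    """Aggregate per-slice class scores (num_slices, num_classes) into one clip-level score."""
    if mode == "mean":
        return slice_probs.mean(axis=0)
    if mode == "max":
        return slice_probs.max(axis=0)
    raise ValueError(f"unknown mode: {mode}")


# 1.5 windows of input: slices cover frames [0, 1024) and [512, 1536),
# analogous to cutting a 15-second clip into [0-10] and [5-15].
mel = np.zeros((1536, 64), dtype=np.float32)
slices = slice_frames(mel)
print(slices.shape)  # (2, 1024, 64)

# Placeholder scores; in practice each row would come from running HTS-AT on one slice.
fake_probs = np.array([[0.75, 0.25], [0.25, 0.5]])
print(vote(fake_probs), vote(fake_probs, "max"))
```

Aligning the last window to the end of the clip avoids padding the tail, at the cost of a larger overlap between the final two slices.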

CaptainPrice12 commented 2 years ago

Hi Ke, thanks for your quick reply! I have some other questions about HTS-AT.

  1. As mentioned in Section 2.1.1 of the paper, the mel-spectrogram is split into patch windows, and patches are then split inside each window. The code actually reshapes the mel-spectrogram from a rectangle (1024x64) into a square ((1024/4)x(64x4) = 256x256), following the Swin Transformer setting for image classification with square inputs. Have you compared this patching method with directly splitting non-overlapping patches without reshaping the mel-spectrogram into a square (I guess that method would also work with Swin Transformer)? By the way, I guess one reason for reshaping the input to 256x256 is that it allows using the Swin-T model pre-trained on ImageNet.

  2. I am assuming that the training details in Section 3.1.1 of the paper are for training HTS-AT from scratch. Could you provide more training details about using the pre-trained Swin-T/C24?

Thank you!
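The (1024/4)x(64x4) arithmetic in question 1 can be illustrated with a small NumPy sketch. `fold_to_square` and `unfold_from_square` are hypothetical helpers, not HTS-AT's `reshape_wav2img`, whose exact axis order may differ. They assume a (time, mel_bins) layout, cut time into 4 consecutive chunks of 256 frames, and place the chunks side by side along the frequency axis, so the result is a pure, invertible rearrangement of values.

```python
import numpy as np


def fold_to_square(mel: np.ndarray, chunks: int = 4) -> np.ndarray:
    """Fold a (T, F) spectrogram into (T / chunks, F * chunks).

    Consecutive time chunks are placed side by side along the frequency axis,
    so (1024, 64) becomes (256, 256). No values are changed, only moved.
    """
    T, F = mel.shape
    assert T % chunks == 0, "time length must split evenly into chunks"
    return mel.reshape(chunks, T // chunks, F).transpose(1, 0, 2).reshape(T // chunks, chunks * F)


def unfold_from_square(img: np.ndarray, chunks: int = 4) -> np.ndarray:
    """Invert fold_to_square: (T / chunks, F * chunks) back to (T, F)."""
    t, cf = img.shape
    F = cf // chunks
    return img.reshape(t, chunks, F).transpose(1, 0, 2).reshape(chunks * t, F)


mel = np.arange(1024 * 64, dtype=np.float32).reshape(1024, 64)  # time x mel bins
img = fold_to_square(mel)
print(img.shape)                                     # (256, 256)
print(np.array_equal(unfold_from_square(img), mel))  # True
# Frame 256 (start of the second chunk) sits in row 0, columns 64..127.
print(np.array_equal(img[0, 64:128], mel[256]))      # True
```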

RetroCirce commented 2 years ago

Sure,

  1. Yes, as you said, we reshape it to 256 x 256 because we need to load the Swin-T checkpoint. If you want to use a non-square shape, the most important thing is not the shape but the order. The input shape only affects which patches each window attention is computed over. Of course, you can change the window attention size (i.e., adapt it to another shape while still attending to the same patches). If you use the right order (i.e., the time-frequency-window order we propose in the paper), the result will be almost the same, because the Swin Transformer still treats the input as a 1D sequence, except that it applies a reshape operation in the window attention step.

  2. In Table 1, HTS-AT(hc) is trained from scratch and HTS-AT(hcp) is trained from Swin-T/C24. When starting from Swin-T/C24, the model converges much faster: it reaches a very good mAP (mAP > 0.4) in 3-5 epochs, then needs more epochs to get better results. If you train HTS-AT from scratch, it takes about 15-20 epochs to reach a good mAP (mAP > 0.43), then some more epochs to converge. Since Swin-T/C24 brings in additional data from images, it gets better results, which makes sense.
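Point 1 above ("the order matters, not the shape") can be checked with a small NumPy sketch that lists which tokens share a window under plain, non-shifted window attention. Everything here is an illustrative assumption rather than HTS-AT code: a 1024 x 64 spectrogram cut into 4 x 4 patches gives a (256 time, 16 freq) token grid, `fold` is a hypothetical helper that places 4 time chunks side by side to get a 64 x 64 grid, and shifted windows, position embeddings and patch merging are ignored.

```python
import numpy as np


def window_groups(grid: np.ndarray, wh: int, ww: int) -> set[frozenset[int]]:
    """Token-id groups seen by plain (non-shifted) window attention on a 2D token grid.

    Each non-overlapping wh x ww tile is one attention window; tokens only attend
    within their tile. Uses the usual reshape/transpose window-partition trick.
    """
    H, W = grid.shape
    assert H % wh == 0 and W % ww == 0, "grid must tile evenly into windows"
    tiles = grid.reshape(H // wh, wh, W // ww, ww).transpose(0, 2, 1, 3).reshape(-1, wh * ww)
    return {frozenset(tile.tolist()) for tile in tiles}


def fold(tokens: np.ndarray, chunks: int) -> np.ndarray:
    """Place consecutive time chunks of a (time, freq) token grid side by side along freq."""
    T, F = tokens.shape
    return tokens.reshape(chunks, T // chunks, F).transpose(1, 0, 2).reshape(T // chunks, chunks * F)


rect = np.arange(256 * 16).reshape(256, 16)  # token ids: 256 time x 16 freq
square = fold(rect, chunks=4)                # 64 x 64 layout

# 8 x 8 windows never straddle a chunk boundary (64 time and 16 freq tokens per chunk),
# so the folded square and the original rectangle group exactly the same tokens.
print(square.shape)                                               # (64, 64)
print(window_groups(square, 8, 8) == window_groups(rect, 8, 8))  # True

# A naive row-major reshape to 64 x 64 scrambles the time order and changes the groups.
print(window_groups(rect.reshape(64, 64), 8, 8) == window_groups(rect, 8, 8))  # False
```

With shifted windows the cyclic shift wraps differently in the two layouts, which may be one reason the results are only "almost" the same rather than identical.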