RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License
341 stars 62 forks source link

About shape of input wav #40

Closed wangqian621 closed 1 year ago

wangqian621 commented 1 year ago

Hi,thanks for your paper and code~ 我想做类似Mask audio prediction的任务,需要利用帧级别的输入,但是我发现输入的维度是[batch_size,1,256,256],patch embedding是[Batch_size,4096,96],而预测输出的fine_grained_embedding:torch.Size([Batch_size, 1024, 768]),我想输入对应上输出的1024.请问论文中提到的pad to 1024frame在哪里呢? image

RetroCirce commented 1 year ago

你好! pad或者interpolation是在这个地方做的:

首先在这里 https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/model/htsat.py#L833C16-L833C44 这个地方是训练数据(以及绝大部分数据进入模型的入口)经过一个叫做reshape_wav2img的函数,这个地方主要是我们想通过这个函数把spec转成接近图片的格式,因为我们用到的某些函数支持的是图片类型(基本上H=W)的输入,同时,如果把数据reshape成图片的shape,可以让我们使用imagenet pretrain的weight -->毕竟他们一开始都是在图片上预训练的

这个函数具体的做法在这里 https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/model/htsat.py#L729C4-L744C17 你可以看到,我们定义了target_F 和 target_T,你可以看看他具体的数值,就是64和1024 x进入这个函数其实他的shape是(1000+1,64)这个是STFT+mel在我们的参数下给出的结果 我们会把1000 interpolate 成1024,其实如果frequency bin不是64的话,我们也会把frequency bin interpolate成64. 使用interpolate的原因是因为我们觉得这样子不用补0,其实测试会发现补0也是没问题的,因为24帧太短了,模型不会因为这个改变太多的训练和测试结果。

我们在文章中确实写了padding,实际上我们最后用的是interpolation,这些代码也是经过了几个迭代的修改 -->你可以发现这个代码有很多分支和注释,这些都是我们测试后发现没有那么多用处的东西,就被废弃了(如果对你造成了阅读上的一些麻烦,深表歉意)。我会考虑在之后的版本中更新一个更简洁的版本。