microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI

[BEATs] Evaluation results inconsistent. #1012

Closed RicherMans closed 1 year ago

RicherMans commented 1 year ago

Hey there, I checked out your BEATs model recently for some work of mine and found that it does not perform as expected.

So I downloaded all of the provided checkpoints and remapped their labels to the "standard" ones from class_label_indices.csv, using this mapping (see the sketch after the list for how such a mapping can be derived):

mappings = [
            134, 84, 137, 474, 506, 141, 140, 143, 138, 462, 139, 459, 142,
            519, 509, 194, 192, 189, 9, 10, 0, 378, 413, 467, 151, 171, 164,
            243, 162, 239, 172, 165, 168, 161, 14, 433, 301, 304, 360, 169, 92,
            166, 120, 119, 332, 330, 328, 329, 505, 27, 11, 507, 445, 368, 489,
            498, 346, 453, 512, 95, 86, 72, 26, 281, 32, 247, 481, 218, 354,
            357, 230, 376, 398, 510, 426, 435, 434, 401, 358, 427, 430, 24,
            300, 317, 298, 446, 64, 522, 299, 349, 62, 321, 308, 337, 68, 222,
            431, 484, 326, 526, 215, 159, 163, 371, 370, 288, 386, 525, 220,
            223, 392, 469, 277, 353, 312, 313, 315, 265, 444, 410, 404, 170,
            58, 508, 292, 381, 339, 285, 387, 449, 488, 307, 348, 456, 60, 472,
            255, 343, 351, 112, 514, 48, 126, 127, 389, 390, 487, 504, 73, 74,
            80, 7, 219, 133, 132, 19, 47, 122, 45, 236, 274, 241, 98, 100, 99,
            177, 176, 440, 470, 394, 46, 111, 113, 359, 155, 156, 208, 195,
            325, 322, 396, 94, 428, 193, 191, 53, 88, 461, 374, 331, 492, 33,
            283, 294, 302, 185, 197, 188, 69, 70, 439, 437, 145, 278, 463, 464,
            306, 412, 513, 438, 443, 242, 284, 49, 268, 411, 117, 118, 180,
            181, 179, 82, 90, 91, 263, 336, 340, 335, 415, 186, 184, 258, 158,
            225, 65, 500, 135, 515, 269, 493, 257, 256, 21, 436, 18, 20, 297,
            251, 226, 420, 271, 124, 131, 129, 379, 125, 466, 110, 397, 369,
            238, 254, 405, 409, 494, 30, 454, 97, 96, 516, 496, 71, 187, 324,
            323, 280, 432, 85, 477, 224, 403, 296, 206, 205, 6, 42, 266, 89,
            149, 249, 28, 221, 167, 318, 363, 272, 270, 475, 382, 104, 106,
            311, 483, 264, 209, 384, 347, 212, 87, 482, 41, 229, 231, 400, 276,
            517, 388, 425, 418, 424, 34, 107, 210, 39, 123, 196, 232, 244, 175,
            275, 316, 341, 393, 361, 12, 81, 83, 279, 460, 75, 486, 352, 344,
            391, 8, 303, 373, 174, 102, 103, 253, 246, 144, 423, 476, 447, 451,
            54, 408, 448, 452, 282, 429, 502, 345, 520, 422, 327, 310, 116,
            115, 108, 365, 364, 146, 182, 233, 154, 157, 152, 153, 442, 93, 56,
            147, 295, 305, 495, 183, 356, 273, 52, 338, 40, 227, 23, 499, 190,
            240, 441, 79, 402, 248, 421, 4, 51, 342, 366, 200, 362, 38, 31,
            201, 15, 491, 406, 395, 211, 333, 259, 485, 160, 468, 105, 380,
            252, 458, 101, 261, 3, 37, 25, 377, 13, 289, 287, 417, 416, 35,
            150, 479, 399, 350, 29, 521, 286, 450, 234, 503, 63, 523, 260, 497,
            202, 178, 245, 57, 407, 109, 121, 524, 319, 130, 199, 237, 419, 44,
            320, 213, 148, 61, 290, 16, 17, 355, 375, 309, 43, 114, 203, 77,
            214, 173, 67, 128, 66, 235, 76, 78, 501, 385, 55, 217, 36, 216,
            293, 473, 478, 22, 490, 367, 414, 207, 198, 471, 262, 59, 204, 267,
            480, 250, 518, 465, 228, 457, 1, 334, 372, 511, 455, 5, 314, 2,
            291, 383, 136, 50]
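
For reference, a sketch of how such a mapping could be derived automatically. This is only an illustration, not my actual script; it assumes the checkpoint carries a label_dict that maps each model output index to an AudioSet label identifier which also appears in class_label_indices.csv (columns index, mid, display_name):

import csv
import torch

# Assumption: checkpoint['label_dict'] maps model output index -> AudioSet label id.
checkpoint = torch.load('BEATs_iter1_finetuned_on_AS2M_cpt1.pt')
idx_to_label = checkpoint['label_dict']

# Build label identifier -> "standard" index from the AudioSet label list,
# matching on whichever column label_dict actually uses.
label_to_standard = {}
with open('class_label_indices.csv') as f:
    for row in csv.DictReader(f):
        label_to_standard[row['mid']] = int(row['index'])
        label_to_standard[row['display_name']] = int(row['index'])

# mappings[i] = standard AudioSet index of the class at model output i
mappings = [label_to_standard[idx_to_label[i]] for i in range(len(idx_to_label))]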

I pretty much followed the description of the BEATs model and computed the logits the following way:

class BeatsWrapper(torch.nn.Module):

    def __init__(self,
                 checkpoint=Path(__file__).parent / 'BEATs_iter1_finetuned_on_AS2M_cpt1.pt'):
        super().__init__()
        # load the fine-tuned checkpoint and build the model
        checkpoint = torch.load(checkpoint)
        cfg = BEATsConfig(checkpoint['cfg'])
        mdl_impl = BEATs(cfg)
        mdl_impl.load_state_dict(checkpoint['model'])
        self.mdl_impl = mdl_impl

    def forward(self, x):
        padding_mask = torch.zeros_like(x).bool()
        probs = self.mdl_impl.extract_features(
            x, padding_mask=padding_mask)[0]
        # scatter the model outputs into the "standard" AudioSet label order
        tar_prob = torch.zeros_like(probs)
        tar_prob[..., mappings] = probs
        return tar_prob

Then I ran my evaluation on my own AudioSet eval set (18229 samples). Since AudioSet is generally hard to come by and everybody has a different evaluation set, I always assume that some results are luckier than others. It especially depends on when you downloaded the dataset, since over time more and more samples go missing, which leads to large performance differences, especially for rare classes. When I ran my evaluation I first couldn't reproduce your results, so I thought that maybe my dataset is very different from "all the others". Therefore, I double-checked with the other popular open-source models AST and PANNs and obtained the following:

Model           Official mAP    mAP on my eval set
AST             45.9            45.8
PANNs CNN14     43.8            43.7
BEATs Iter 1    47.9            45.4
BEATs Iter 2    48.0            44.1
BEATs Iter 3    47.9            43.3
BEATs Iter 3+   48.6            46.7
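
For reference, mAP here means the macro-average of per-class average precision over the 527 classes. A minimal sketch of how I compute it, assuming scikit-learn and a generic eval_loader that yields (audio, multi-hot target) batches; my actual evaluation code differs in the details, and since AP is rank-based it does not matter whether a sigmoid is applied:

import numpy as np
import torch
from sklearn.metrics import average_precision_score

model = BeatsWrapper()
model.eval()

all_scores, all_targets = [], []
with torch.no_grad():
    for audio, targets in eval_loader:      # targets: (batch, 527) multi-hot
        scores = model(audio)               # remapped class scores
        all_scores.append(scores.cpu().numpy())
        all_targets.append(targets.numpy())

# macro mAP: average precision per class, averaged over classes
# (assumes every class has at least one positive in the eval set)
mAP = average_precision_score(np.concatenate(all_targets),
                              np.concatenate(all_scores),
                              average='macro')
print(mAP)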

So it seems to me that my evaluation dataset is similar to the ones used for PANNs and AST. Did I make a mistake here, or is your evaluation set perhaps different from most others?

Kind regards, Heinrich Dinkel

Sanyuan-Chen commented 1 year ago

@RicherMans Hi, it seems that you didn't set the model to eval() mode (i.e. model.eval()) during inference. Training mode enables dropout and LayerDrop, which hurts inference performance.
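
For reference, a minimal snippet of what I mean (model, audio, and padding_mask stand in for your own objects):

model.eval()                      # disables dropout and LayerDrop
with torch.no_grad():             # gradients are not needed at inference time
    probs = model.extract_features(audio, padding_mask=padding_mask)[0]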

RicherMans commented 1 year ago

Hey @Sanyuan-Chen, nope, I did use .eval(); otherwise the baseline results of AST and HT-AST would be equally wrong. I ran all these experiments at least 3 times and all the results reported above were consistent. Sorry that I didn't provide the entire codebase, but each model is first set to .eval() for evaluation. I also just double-checked and edited the above code to:

class BeatsWrapper(torch.nn.Module):

    def __init__(self,
                 checkpoint=Path(__file__).parent / 'BEATs_iter1_finetuned_on_AS2M_cpt1.pt'):
        super().__init__()
        # load the fine-tuned checkpoint and build the model
        checkpoint = torch.load(checkpoint)
        cfg = BEATsConfig(checkpoint['cfg'])
        mdl_impl = BEATs(cfg)
        mdl_impl.load_state_dict(checkpoint['model'])
        self.mdl_impl = mdl_impl
        self.mdl_impl.eval()   # inference mode: no dropout / LayerDrop

    def forward(self, x):
        padding_mask = torch.zeros_like(x).bool()
        probs = self.mdl_impl.extract_features(
            x, padding_mask=padding_mask)[0]
        # scatter the model outputs into the "standard" AudioSet label order
        tar_prob = torch.zeros_like(probs)
        tar_prob[..., mappings] = probs
        return tar_prob

After the changes, the results remain the same:

[screenshot of the evaluation results omitted]

Then I checked for some other possible problems, for example the padding mask, i.e., whether the padding (which sometimes exists, because not all samples are 10 s long) has an impact. With padding we get 45.45; when adjusting the padding mask accordingly, we get 45.42, so not a very noticeable drop.

I currently believe that your evaluation split is a much easier one than most people use in other works. I personally blame AudioSet for this, since, as I said above, everybody has a different eval and training split, which is not very useful for reproduction.

Let me ask whether you have tried, for example, to evaluate AST on your own eval split? I suspect it would score higher than your iter 0 model. If so, please adjust the results accordingly.

Sanyuan-Chen commented 1 year ago

Hi @RicherMans

As you suggested, I just tried to evaluate AST on our eval split, and the result is mAP: 0.45886105276409356, which matches their official mAP score. So I think our eval set is consistent with the others.

During the evaluation, I noticed that not all the eval audio is exactly 10 s long, and some previous works first crop/pad the input audio to a fixed length before feeding it to the model. Instead, our model consumes the audio at its original length, and a padding mask should be provided if a padding operation is applied.
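
For example, a minimal sketch of batching variable-length clips with an explicit padding mask, following the convention used above where False marks real samples and True marks padding (this is an illustration, not our exact evaluation code):

import torch

def batch_with_padding_mask(waveforms):
    """waveforms: list of 1-D float tensors of (possibly) different lengths."""
    lengths = [w.shape[-1] for w in waveforms]
    max_len = max(lengths)
    batch = torch.zeros(len(waveforms), max_len)
    padding_mask = torch.ones(len(waveforms), max_len, dtype=torch.bool)
    for i, (w, n) in enumerate(zip(waveforms, lengths)):
        batch[i, :n] = w
        padding_mask[i, :n] = False   # real samples are unmasked
    return batch, padding_mask

# probs = model.extract_features(batch, padding_mask=padding_mask)[0]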

Besides, I am currently preparing files with a) my eval audio list (name + length), b) the probabilities predicted by the BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1 model, and c) the target probabilities. Please check which eval audios are available on your side, and see if your eval code predicts the same probabilities.

RicherMans commented 1 year ago

Hey @Sanyuan-Chen, as I mentioned before, I already checked for padding, using 1) the masking operation in BEATs (with little impact) and 2) batch size 1, which avoids padding altogether. Both results are consistent on my side.

> As you suggested, I just tried to evaluate AST on our eval split, and the result is mAP: 0.45886105276409356, which matches their official mAP score. So I think our eval set is consistent with the others.

So now it seems weird that my result is so far off yours, I am still thinking is there some parameter that I haven't set properly? Are my mappings maybe not 100% correct here?

> Besides, I am currently preparing files with a) my eval audio list (name + length), b) the probabilities predicted by the BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1 model, and c) the target probabilities. Please check which eval audios are available on your side, and see if your eval code predicts the same probabilities.

So I also just generated my outputs for that model (here I used my mapping to the "default" class label indices and applied a sigmoid).

Here are the outputs for BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1 from my side:

wget https://transfer.sh/zb8Xyq/eval_beats_iter3.pt
python3 -c "import torch; torch.load('eval_beats_iter3.pt')"

I also checked for discrepancies between PyTorch versions: on PyTorch 1.8.2 I got 46.72 for Iter3 Plus, and on PyTorch 1.13 I also got 46.72 for Iter3 Plus.

Sanyuan-Chen commented 1 year ago

Hi @RicherMans

I checked the output for your first audio '--4gqARaEJE' and found that it is different from mine.

Here is my inference code:

from BEATs import BEATs, BEATsConfig
import torch
import soundfile as sf

# load the fine-tuned checkpoint and build the model
checkpoint = 'BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt'
checkpoint = torch.load(checkpoint)
cfg = BEATsConfig(checkpoint['cfg'])
mdl_impl = BEATs(cfg)
mdl_impl.load_state_dict(checkpoint['model'])
mdl_impl.eval()

# run inference on a single clip (no padding, so an all-False mask)
audio, sr = sf.read('--4gqARaEJE.wav')
audio = torch.from_numpy(audio).float().unsqueeze(0)
padding_mask = torch.zeros_like(audio).bool()
probs = mdl_impl.extract_features(audio, padding_mask=padding_mask)[0]

print(probs[0].topk(10))

And it outputs:

torch.return_types.topk(
values=tensor([0.9541, 0.9004, 0.8234, 0.3684, 0.2184, 0.1404, 0.1353, 0.0659, 0.0442,
        0.0259], grad_fn=<TopkBackward>),
indices=tensor([ 61, 148, 149,  81, 480, 157, 150, 487, 431, 400]))

After label mapping, the output is:

torch.return_types.topk(
values=tensor([0.9541, 0.9004, 0.8234, 0.3684, 0.2184, 0.1404, 0.1353, 0.0659, 0.0442,
        0.0259], grad_fn=<TopkBackward>),
indices=tensor([ 72,  73,  74,  24,  77, 122,  80,  76,  25,  79]))

And the output from your file is:

torch.return_types.topk(
values=tensor([0.9575, 0.8933, 0.7965, 0.3284, 0.1949, 0.1224, 0.1018, 0.0490, 0.0338,
        0.0250]),
indices=tensor([ 72,  73,  74,  24,  77, 122,  80,  76,  25,  79]))
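
(For reference, the remapped indices are just an elementwise lookup into the mappings list from your first post, e.g.:)

import torch

raw_idx = torch.tensor([61, 148, 149, 81, 480, 157, 150, 487, 431, 400])
mapped_idx = torch.tensor(mappings)[raw_idx]
# -> tensor([ 72,  73,  74,  24,  77, 122,  80,  76,  25,  79])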

Here is my input wav

My torch versions:

torch                        1.8.1+cu111
torchaudio                   0.8.1

RicherMans commented 1 year ago

Hey @Sanyuan-Chen, thanks for the quick reply!

And thanks for the provided file.

So I checked with the file you provided, and yes, they are indeed different. First, the lengths do not match: yours is 10.00 seconds long, but mine is (due to resampling) 10.11. While I don't think 0.11 seconds makes a difference, it might if you have some fixed-length parameters etc. But usually that additional length is "eaten" by the patch embedding.

Second, and much more importantly, your files are low-pass filtered. For example, my file's spectrogram is:

[spectrogram of my file omitted]

and yours:

[spectrogram of the provided file omitted]

It is not very noticeable, but your maximum frequency is below 8 kHz (it seems to be around 7.8 kHz), while I use the full 8 kHz.
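
For reference, a quick programmatic check for such a low-pass difference (a rough sketch assuming torchaudio; the file paths are placeholders, and my spectrograms above were plotted differently):

import torch
import torchaudio

def highband_energy_ratio(path, cutoff_hz=7800, n_fft=1024):
    wav, sr = torchaudio.load(path)                       # (channels, samples)
    spec = torch.stft(wav.mean(0), n_fft=n_fft,
                      window=torch.hann_window(n_fft),
                      return_complex=True).abs() ** 2     # (freq_bins, frames)
    freqs = torch.linspace(0, sr / 2, spec.shape[0])
    return (spec[freqs >= cutoff_hz].sum() / spec.sum()).item()

# a (near-)zero ratio for one file but not the other points to low-pass filtering
print(highband_energy_ratio('mine/--4gqARaEJE.wav'),
      highband_energy_ratio('provided/--4gqARaEJE.wav'))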

Thus I can currently conclude two things:

  1. Does BEATs maybe evaluate a clip longer than 10.00 s twice? Is there a parameter for this somewhere?
  2. Your evaluation data indeed differs from mine, but what is weird about the situation is that the other models are not affected by the difference in audio, while BEATs is heavily affected.

Anyway, thanks again for the quick reply!

I don't know how to proceed from here; at least in my publications I would currently place an asterisk next to the BEATs results due to this discrepancy.

Both files are:

difference_for_eval_audioset.zip

Sanyuan-Chen commented 1 year ago

Hi @RicherMans

Thanks for the detailed analysis!

  1. BEATs deals with all clips in the same way, regardless of their duration.
  2. It is also weird to me. I'm not sure if it is caused by different resampling methods. I download the original audio from YouTube with the youtube-dl toolkit and then resample the clips to 16 kHz with sox raw.wav -G -c 1 -r 16000 raw_c1_sr16k.wav.

RicherMans commented 1 year ago

Hey, so thanks to you too for the quick and insightful analysis of your model.

> It is also weird to me. I'm not sure if it is caused by different resampling methods. I download the original audio from YouTube with the youtube-dl toolkit and then resample the clips to 16 kHz with sox raw.wav -G -c 1 -r 16000 raw_c1_sr16k.wav.

Yeah, that can be an issue here, since yt-dlp and youtube-dl do not always download the same audio. You can also adjust the original sample quality, e.g., with -F (to list the available formats) and -f to select one. Thus we might end up with samples of different quality.

So from my end I'll close the issue, since I think we simply have different evaluation sets.

From our discussion I find it strange that we both get the same results for AST on both datasets, yet the results for BEATs differ greatly. I can only conclude from these experiments that BEATs is quite overfitted to a specific sample quality, unlike AST and some other models.

Thanks for everything! Kind regards, Heinrich Dinkel

ShadowVicky commented 1 year ago

@RicherMans Hey! Can you tell me how I can run the BEATs model on my own dataset?

RicherMans commented 1 year ago

Hey @ShadowVicky, nope, I only use the models for evaluation. So far, neither I nor the other researchers I have talked to could replicate the results.

ShadowVicky commented 1 year ago

@RicherMans We also need this for evaluation purposes, so it would help us a lot if you could assist.

RicherMans commented 11 months ago

Hey @ShadowVicky ,

Does anything in this thread help? I think I made my own evaluation code pretty clear here:

Hey @Sanyuan-Chen, nope, I did use .eval(); otherwise the baseline results of AST and HT-AST would be equally wrong. I ran all these experiments at least 3 times and all the results reported above were consistent. Sorry that I didn't provide the entire codebase, but each model is first set to .eval() for evaluation. I also just double-checked and edited the above code to:

class BeatsWrapper(torch.nn.Module):

    def __init__(self,
                 checkpoint=Path(__file__).parent / 'BEATs_iter1_finetuned_on_AS2M_cpt1.pt'):
        super().__init__()
        # load the fine-tuned checkpoint and build the model
        checkpoint = torch.load(checkpoint)
        cfg = BEATsConfig(checkpoint['cfg'])
        mdl_impl = BEATs(cfg)
        mdl_impl.load_state_dict(checkpoint['model'])
        self.mdl_impl = mdl_impl
        self.mdl_impl.eval()   # inference mode: no dropout / LayerDrop

    def forward(self, x):
        padding_mask = torch.zeros_like(x).bool()
        probs = self.mdl_impl.extract_features(
            x, padding_mask=padding_mask)[0]
        # scatter the model outputs into the "standard" AudioSet label order
        tar_prob = torch.zeros_like(probs)
        tar_prob[..., mappings] = probs
        return tar_prob

After the changes, the results remain the same:

[screenshot of the evaluation results omitted]

Then I checked for some other possible problems, for example the padding mask, i.e., whether the padding (which sometimes exists, because not all samples are 10 s long) has an impact. With padding we get 45.45; when adjusting the padding mask accordingly, we get 45.42, so not a very noticeable drop.

I currently believe that your evaluation split is a much easier one than most people use in other works. I personally blame AudioSet for this, since, as I said above, everybody has a different eval and training split, which is not very useful for reproduction.

Let me ask whether you have tried, for example, to evaluate AST on your own eval split? I suspect it would score higher than your iter 0 model. If so, please adjust the results accordingly.

ByeonggeunKim commented 9 months ago

Hello team,

I would like to express my gratitude for sharing the code for the intriguing BEATs project. Recently, I attempted to replicate the results presented in your manuscript using the provided checkpoints. However, my findings align with those of @RicherMans.

I have included the specific numbers below for your reference. I would greatly appreciate any advice or guidance you can offer.

mAP on AS-2M: iter1 45.3, iter2 44.1, iter3 42.9, iter3+M 46.5

FYI, our evaluation set comprises 19.7k samples, whereas yours consists of 19k samples.

Warm regards, Byeonggeun.

RicherMans commented 9 months ago

Hey @ByeonggeunKim, so let me maybe explain what happens from my point of view. The authors of BEATs downloaded a "CD quality" version of AudioSet with a sample rate of 32 kHz or higher. Then they used sox to preprocess this dataset, which gives different results from, e.g., ffmpeg. Sox does the reasonable thing of first low-passing the signal and then subsampling, whereas ffmpeg just subsamples directly. In general the difference between these two methods should be marginal; as you can see earlier in this thread, only some content in the 7.8 - 8 kHz range differs.
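
To illustrate the point, here is a toy sketch using torchaudio's resampler with two different low-pass settings. It is only a stand-in for the actual sox/ffmpeg pipelines, not a reproduction of either tool:

import torch
import torchaudio.functional as F

wav32k = torch.randn(1, 32000 * 10)   # stand-in for 10 s of 32 kHz audio

# gentle anti-aliasing filter vs. a much sharper one
loose = F.resample(wav32k, 32000, 16000, lowpass_filter_width=6, rolloff=0.85)
sharp = F.resample(wav32k, 32000, 16000, lowpass_filter_width=64, rolloff=0.99)

# the two 16 kHz versions differ slightly, mainly in the highest frequencies
print((loose - sharp).abs().mean())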

For all models that I have ever trained and tested on AudioSet, this difference is marginal, usually 0.1 to 0.2 points of mAP. However, BEATs differs by quite a bit. The main reason, I believe, is the model's positional encoding: each "patch" extracted from the spectrogram is relatively dependent on every other patch within a range of 2.5 s. That is, if the patches covering the high-frequency part of the spectrum change even slightly compared to the training regime, their positional embeddings are likely to produce wrong values.

How to fix that? Other than fine-tuning on your own dataset, there is no way to fix it directly. And even if you fine-tune, the issue remains the same: you would have to assume that at test time all audio is sampled exactly the same way as your training data (i.e., recorded at 32 kHz, then subsampled, and so on). Thus, I believe this model is less useful for actual applications, where microphones, recording conditions and so on are unknown.

Kind regards, Heinrich Dinkel

ByeonggeunKim commented 9 months ago

Hi @RicherMans,

Thank you for sharing your experience and opinions :) I ran a straightforward experiment and wanted to share my observations with you. I applied a low-pass filter at 7.8 kHz to my eval set. The results were notably different from those without low-pass filtering. Here is a summary: iter1: 45.3 --> 47.2, iter2: 44.1 --> 47.4, iter3: 42.9 --> 47.2, iter3+M: 46.5 --> 48.0
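
(A minimal sketch of the kind of low-pass filtering I mean, using torchaudio's biquad filter with a placeholder filename; the exact filter I used may differ:)

import torchaudio
import torchaudio.functional as F

wav, sr = torchaudio.load('eval_clip.wav')                 # 16 kHz eval audio
wav_lp = F.lowpass_biquad(wav, sample_rate=sr, cutoff_freq=7800.0)
torchaudio.save('eval_clip_lp.wav', wav_lp, sr)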

Thus, it appears that the disparities come from the data, as you mentioned earlier.

Best, Byeonggeun.