YuanGongND / ssast

Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Are the model weights loaded into the ASTModel with stochasticity? #30

Open michaelschwob opened 9 months ago

michaelschwob commented 9 months ago

I have a best_audio_model.pth file produced by fine-tuning a pretrained SSAST model. My specific .pth file can be found here for reproducibility. Note that this is a tiny model fine-tuned on 5 of the classes in the ESC-50 dataset; I chose 5 classes to keep the computation and plotting simple.

The following code (a minimal working example) shows that two model initializations end up with different weights, even though the constructor arguments and the checkpoint are the same. Specifically, mlp_head.1.weight (and possibly other parameters) differs between the two models.

import os
import sys

import torch

# make SSAST's model definition importable
sys.path.insert(0, 'ssast/src/models')
from ast_models import ASTModel

# the checkpoint path used below is relative to ssast/src
os.chdir('ssast/src')

## to minimize stochasticity in model evaluation
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## initialize the model (ensure parameters match those used during training)
model_1 = ASTModel(label_dim=5, fshape=16, tshape=16, fstride=16, tstride=16,
                   input_fdim=128, input_tdim=512, model_size='tiny', pretrain_stage=False,
                   load_pretrained_mdl_path='./best_audio_model.pth')

## set the model to evaluation mode
model_1.eval()

## initialize a second model with exactly the same arguments
model_2 = ASTModel(label_dim=5, fshape=16, tshape=16, fstride=16, tstride=16,
                   input_fdim=128, input_tdim=512, model_size='tiny', pretrain_stage=False,
                   load_pretrained_mdl_path='./best_audio_model.pth')

## set the model to evaluation mode
model_2.eval()

###
### Model Comparison
###

def are_models_identical(model_1, model_2):
    model_1_state_dict = model_1.state_dict()
    model_2_state_dict = model_2.state_dict()

    for p in model_1_state_dict:
        if p not in model_2_state_dict:
            print(f"Missing parameter {p} in model 2")
            return False
        if not torch.equal(model_1_state_dict[p], model_2_state_dict[p]):
            print(f"Difference found in parameter: {p}")
            return False

    # Check if model_2 has extra parameters not in model_1
    for p in model_2_state_dict:
        if p not in model_1_state_dict:
            print(f"Extra parameter {p} in model 2 not found in model 1")
            return False

    return True

# Example usage
identical = are_models_identical(model_1, model_2)
print(f"Models are {'identical' if identical else 'different'}")

print("Model 1 'mlp_head.1.weight' mean:", model_1.state_dict()['mlp_head.1.weight'].mean())
print("Model 1 'mlp_head.1.weight' std:", model_1.state_dict()['mlp_head.1.weight'].std())

print("Model 2 'mlp_head.1.weight' mean:", model_2.state_dict()['mlp_head.1.weight'].mean())
print("Model 2 'mlp_head.1.weight' std:", model_2.state_dict()['mlp_head.1.weight'].std())
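The same symptom can be reproduced without SSAST. In this minimal sketch, `make_head` is a hypothetical stand-in for a LayerNorm + Linear fine-tuning head (not SSAST code), and the sizes are illustrative. The determinism flags used in the MWE control which kernels run, not how new layers are initialized, so two freshly built heads differ unless the global RNG is re-seeded before each construction.

```python
import torch
import torch.nn as nn

# Same flags as in the MWE; they affect kernel selection, not parameter initialization.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def make_head(embed_dim=192, label_dim=5):
    # Hypothetical stand-in for a fine-tuning head (illustrative sizes).
    # nn.Linear draws its initial weights from torch's global RNG.
    return nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))


def heads_equal(a, b):
    # Compare every tensor in the two state dicts (keys like "1.weight").
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


# Without seeding, the two constructions consume different random numbers.
unseeded_equal = heads_equal(make_head(), make_head())

# Re-seeding before each construction reproduces the same initial weights.
torch.manual_seed(0)
h1 = make_head()
torch.manual_seed(0)
h2 = make_head()
seeded_equal = heads_equal(h1, h2)

print(unseeded_equal, seeded_equal)  # False True
```

Seeding only makes the random head reproducible; it does not make it the trained head.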

Forgive me if my assumption is incorrect, but I would expect the weights of model_1 and model_2 to be identical. I would also expect model_1(inputs, 'ft_avgtok') and model_2(inputs, 'ft_avgtok') to give the same output for identical inputs, which they do not. Is there something simple that I'm missing here?

Once again, thank you for all of your help, Yuan. I greatly appreciate it and am sorry to keep bothering you.

michaelschwob commented 9 months ago

I found that the following code in the ASTModel definition is the source of the difference:

            # mlp head for fine-tuning ; this is different every time
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim),
                                          nn.Linear(self.original_embedding_dim, label_dim))

            # ...

            # patch split stride generally should be different for pretraining and fine-tuning, as patch split overlapping is only used in finetuning
            # during pretraining, p_fshape = p_fstride and p_tshape = p_tstride
            if fstride != p_fshape or tstride != p_tshape:
                # initialize a new patch embedding layer with desired new stride ; this also initializes a new layer with random parameters
                new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
                # but the patch embedding weights are still taken from the pretrained model
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
                self.v.patch_embed.proj = new_proj

            new_pos_embed = self.v.pos_embed[:, self.cls_token_num:, :].detach().reshape(1, p_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, p_f_dim, p_t_dim)
            # cut or interpolate the positional embedding
            if t_dim < p_t_dim:
                new_pos_embed = new_pos_embed[:, :, :, int(p_t_dim/2) - int(t_dim / 2): int(p_t_dim/2) - int(t_dim / 2) + t_dim]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(8, t_dim), mode='bilinear')
            if f_dim < p_f_dim:
                new_pos_embed = new_pos_embed[:, :, int(p_f_dim/2) - int(f_dim / 2): int(p_f_dim/2) - int(f_dim / 2) + t_dim, :]
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')

            new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :self.cls_token_num, :].detach(), new_pos_embed], dim=1))
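As an aside, the frequency branch of this excerpt ends its crop at `int(p_f_dim/2) - int(f_dim / 2) + t_dim`, whereas a center crop of `f_dim` rows would appear to need `+ f_dim` (the time branch adds `t_dim`, which is correct there). That branch only runs when the fine-tuning grid has fewer frequency patches than the pretraining grid, so it may not affect the configuration in this issue. Relatedly, the time-interpolation branch resizes to a fixed 8 frequency rows rather than `p_f_dim`. A pure-Python sketch of the index arithmetic, using a hypothetical helper `center_crop_bounds` and made-up grid sizes, shows the length mismatch:

```python
def center_crop_bounds(full_len, crop_len):
    """Start/end indices that take crop_len entries from the middle of full_len
    entries, using the same int(full/2) - int(crop/2) start as the excerpt."""
    start = int(full_len / 2) - int(crop_len / 2)
    return start, start + crop_len


# Hypothetical grid sizes: pretraining grid p_f_dim x p_t_dim, fine-tuning grid f_dim x t_dim.
p_f_dim, p_t_dim = 12, 64
f_dim, t_dim = 8, 32

rows = list(range(p_f_dim))  # stand-in for the frequency axis of the positional embedding

# Ending the slice at start + f_dim yields exactly f_dim rows.
start, end = center_crop_bounds(p_f_dim, f_dim)
correct = rows[start:end]

# Ending it at start + t_dim over-runs the axis; slicing silently clamps,
# so the crop ends up with the wrong number of rows.
buggy = rows[start:start + t_dim]

print(len(correct), len(buggy))  # 8 10
```

Because Python lists and PyTorch tensors both clamp an out-of-range slice end instead of raising, such a mismatch would typically surface later, in the reshape to `num_patches`.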

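Whenever a new layer is created in this ASTModel excerpt (via nn.Sequential() or torch.nn.Conv2d()), it starts with random weights. For the Conv2d patch projection, the random weight and bias are immediately replaced with values taken from the pretrained model, but the mlp_head is not loaded from the checkpoint here, so it gets a fresh random initialization on every call.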
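The two kinds of new layer in the excerpt behave differently. A minimal sketch with illustrative sizes (not SSAST's actual dimensions; `build` and `pretrained_proj` are hypothetical names) mirrors the patch-embedding branch: a new `Conv2d` starts random, but once its `weight` and `bias` are replaced by tensors derived from an existing layer, two constructions agree regardless of the random start, while a new `Linear` head that is never overwritten keeps its own random draw.

```python
import torch
import torch.nn as nn

# A "pretrained" patch projection, created once and shared by both constructions below.
torch.manual_seed(0)
pretrained_proj = nn.Conv2d(3, 8, kernel_size=4, stride=4)  # illustrative sizes


def build():
    # New projection with a different stride: random at first...
    new_proj = nn.Conv2d(1, 8, kernel_size=4, stride=2)
    # ...then overwritten from the pretrained layer, mirroring the excerpt
    # (sum over input channels, then restore a singleton channel dimension).
    new_proj.weight = nn.Parameter(torch.sum(pretrained_proj.weight, dim=1).unsqueeze(1))
    new_proj.bias = pretrained_proj.bias
    # New classification head: random and never overwritten.
    head = nn.Linear(8, 5)
    return new_proj, head


proj_a, head_a = build()
proj_b, head_b = build()

proj_same = torch.equal(proj_a.weight, proj_b.weight) and torch.equal(proj_a.bias, proj_b.bias)
head_same = torch.equal(head_a.weight, head_b.weight)
print(proj_same, head_same)  # True False
```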

Is this the intended behavior, or should two ASTModel initializations from the same pretrained checkpoint produce identical models?

YuanGongND commented 7 months ago

Thanks for reporting this. This may be valid; the code has a problem with loading a pretrained SSAST model.
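One possible workaround (an assumption, not a fix confirmed in this thread) is to construct the model as before and then load the entire fine-tuned state dict over it, so that `mlp_head` receives the trained weights instead of a fresh random draw. The sketch below uses a hypothetical `TinyClassifier` in place of `ASTModel` and assumes the checkpoint was saved from a `torch.nn.DataParallel`-wrapped model, whose keys carry a `module.` prefix. For `ASTModel`, the analogous step would be either wrapping the model in `torch.nn.DataParallel` before `load_state_dict`, or stripping the prefix as done here.

```python
import io

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for ASTModel: a backbone plus a randomly initialized head."""

    def __init__(self, label_dim=5):
        super().__init__()
        self.backbone = nn.Linear(16, 8)
        self.mlp_head = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, label_dim))

    def forward(self, x):
        return self.mlp_head(self.backbone(x))


# Simulate a fine-tuned checkpoint with DataParallel-style keys ("module.mlp_head.1.weight", ...).
# An in-memory buffer stands in for best_audio_model.pth.
buffer = io.BytesIO()
torch.save({"module." + k: v for k, v in TinyClassifier().state_dict().items()}, buffer)


def load_finetuned(buf):
    model = TinyClassifier()  # head starts random, as in the constructor above
    buf.seek(0)
    state = torch.load(buf, map_location="cpu")
    # Strip the DataParallel prefix so the keys match the bare model.
    state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
    model.load_state_dict(state)  # strict by default: fails loudly on missing/unexpected keys
    return model.eval()


m1, m2 = load_finetuned(buffer), load_finetuned(buffer)
sd1, sd2 = m1.state_dict(), m2.state_dict()
same_weights = all(torch.equal(sd1[k], sd2[k]) for k in sd1)

x = torch.randn(3, 16)
with torch.no_grad():
    same_output = torch.allclose(m1(x), m2(x))
print(same_weights, same_output)  # True True
```

Loading with `strict=True` is a deliberate choice: if the key layout does not match, the call raises instead of silently leaving the head at its random initialization.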