ZhuoYulang / CIF-MMIN

MIT License
19 stars · 2 forks

Did the model see the labels during the pre-training process? #1

Open Cb1ock opened 6 months ago

Cb1ock commented 6 months ago

Hi author, thank you for your excellent work.

I have some questions about the loss computation in the backward function of the UttSelfSuperviseModel class in models/utt_self_supervise_model.py. I noticed that the total loss is the sum of four terms, one of which is loss_CE, computed by:

if self.opt.corpus_name != 'MOSI':
    self.criterion_ce = torch.nn.CrossEntropyLoss()
else:
    self.criterion_ce = torch.nn.MSELoss()
if self.opt.corpus_name != 'MOSI':
    self.loss_CE = self.criterion_ce(self.logits, self.label)
else:
    self.loss_CE = self.criterion_ce(self.logits, self.label)
loss = self.loss_CE + self.loss_TV + self.loss_TA + self.loss_VA
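
For reference, both branches of the second if compute the same expression, so the quoted snippet collapses to the following (a sketch only; opt, logits, label, and the three auxiliary losses come from the surrounding model code):

import torch

criterion_ce = (torch.nn.MSELoss() if opt.corpus_name == 'MOSI'
                else torch.nn.CrossEntropyLoss())
loss_CE = criterion_ce(logits, label)  # supervised term: uses ground-truth labels
loss = loss_CE + loss_TV + loss_TA + loss_VA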

So, here is my question: did the model see the labels during the pre-training process? I noticed the following when looking at the dataset and dataloader in data/multimodal_dataset.py:

def __getitem__(self, index):
    int2name = self.int2name[index]
    if self.corpus_name == 'IEMOCAP':
        int2name = int2name[0].decode()
    label = torch.tensor(self.label[index])
    # print('dataset-label is:', label)
    # process A_feat
    A_feat = torch.from_numpy(self.all_A[int2name][()]).float()
    if self.A_type == 'comparE':
        A_feat = self.normalize_on_utt(A_feat) if self.norm_method == 'utt' else self.normalize_on_trn(A_feat)
    # process V_feat 
    V_feat = torch.from_numpy(self.all_V[int2name][()]).float()
    # process L_feat
    L_feat = torch.from_numpy(self.all_L[int2name][()]).float()
    return {
        'A_feat': A_feat, 
        'V_feat': V_feat,
        'L_feat': L_feat,
        'label': label,
        'int2name': int2name
    }
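
A minimal sketch (a hypothetical loop, not the repo's trainer) of how batches from this Dataset reach the model; it shows that every batch already carries the ground-truth label alongside the features. A real run would also need the repo's collate logic to pad variable-length features:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    a, v, l = batch['A_feat'], batch['V_feat'], batch['L_feat']
    labels = batch['label']  # available to any loss term during pre-training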

The label was obtained from the dataset using this code:

label_path = os.path.join(config['target_root'], f'{cvNo}', f"{set_name}_label.npy")
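
For completeness, a .npy file like this would typically be read back with numpy (a sketch; label_path is built from the config as above):

import numpy as np

labels = np.load(label_path)  # one ground-truth label per utterance in the split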

Looking forward to your reply!

ZhuoYulang commented 6 months ago

I apologize for not getting back to you sooner. Yes, the model saw the labels during the pre-training process: it needs the labels to supervise all of the encoders while they learn during pre-training.
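
To illustrate the point with a toy example (hypothetical shapes and names, not the repo's code): the supervised cross-entropy term backpropagates through every modality encoder, so the encoders receive label supervision even during the "pre-training" stage.

import torch

enc_a = torch.nn.Linear(128, 64)  # toy audio encoder
enc_v = torch.nn.Linear(128, 64)  # toy visual encoder
enc_l = torch.nn.Linear(128, 64)  # toy text encoder
head = torch.nn.Linear(192, 4)    # classifier over 4 toy classes

a, v, l = (torch.randn(8, 128) for _ in range(3))
label = torch.randint(0, 4, (8,))

logits = head(torch.cat([enc_a(a), enc_v(v), enc_l(l)], dim=-1))
loss_CE = torch.nn.functional.cross_entropy(logits, label)
loss_CE.backward()  # gradients flow into all three encoders -> label supervision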