rungjoo / CoMPM

Context Modeling with Speaker's Pre-trained Memory Tracking for Emotion Recognition in Conversation (NAACL 2022)

I wonder if the model works fine when the batch size is not 1? #2

Closed: qftie closed this issue 2 years ago

qftie commented 2 years ago

I don't see any handling of attention_mask. Does that mean the RoBERTa model will set the whole attention_mask to 1?

rungjoo commented 2 years ago

This link may be helpful. https://github.com/huggingface/transformers/blob/v4.19.3/src/transformers/models/roberta/modeling_roberta.py#L807

In RoBERTa, if attention_mask is None, the attention mask for all tokens defaults to 1.
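For reference, here is a minimal sketch of that default behaviour (it uses `roberta-base` as a stand-in checkpoint, not the repo's actual model): omitting `attention_mask` is equivalent to passing an explicit all-ones mask.

```python
import torch
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
model.eval()  # disable dropout so the two forward passes are comparable

enc = tokenizer("hello world", return_tensors="pt")

with torch.no_grad():
    # No mask passed: RoBERTa builds an all-ones mask internally.
    out_no_mask = model(input_ids=enc["input_ids"]).last_hidden_state
    # Explicit all-ones mask gives the same result.
    ones = torch.ones_like(enc["input_ids"])
    out_ones = model(input_ids=enc["input_ids"], attention_mask=ones).last_hidden_state

print(torch.allclose(out_no_mask, out_ones))  # expected: True
```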

qftie commented 2 years ago

If the attention mask for all tokens is 1, wouldn't there be a problem when dealing with multiple sequences? (Since the padding input_ids won't be ignored when calculating attention scores.)

rungjoo commented 2 years ago

Let me explain with an example.

If batch_size = 2:

sample instance 1: [u1; u2; u3]

sample instance 2: [u1; u2; u3; u4]

Input: instance 1 is shorter, so it is padded with pad tokens to the length of instance 2.
As you were concerned, we do not set the attention mask for pad tokens to 0, so the attention mask is 1 even for the padding tokens when batch_size is greater than 1. We missed this part because we set batch_size to 1 when training, so the problem did not occur for us. Even if batch_size is greater than 1 and the mask of the pad tokens is 1, we expect the model can learn to ignore the padding part during training. However, your comment will help the model train more effectively. Thanks.
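To make the concern concrete, here is a small illustration of the batching situation described above (the utterance strings are placeholders, not the repo's real inputs): the tokenizer's own mask zeroes out the pad positions of the shorter instance, while an omitted mask amounts to all ones.

```python
import torch
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Two instances of different length, as in the example above.
batch = ["u1 u2 u3", "u1 u2 u3 u4"]
enc = tokenizer(batch, padding=True, return_tensors="pt")

proper_mask = enc["attention_mask"]                # 0 at instance 1's pad positions
all_ones_mask = torch.ones_like(enc["input_ids"])  # what omitting the mask amounts to

print(proper_mask)
print(all_ones_mask)
# The two differ exactly at the pad positions of the shorter instance,
# which an all-ones mask would (incorrectly) include in attention.
```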

ThomasDeCleen commented 1 year ago

Thank you for your detailed explanation. Can I conclude that during training, when I manually set the batch_size to 16 for example, this will not have a negative impact on training due to attention mask issues? Or am I mistaken and did I read your comment wrong?

rungjoo commented 1 year ago

You got it right. There may be a negative impact, but it is expected to be small. To remove the effect entirely, the attention corresponding to the padding tokens must be set to 0.
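A minimal sketch of that fix, assuming the standard Hugging Face tokenizer/model API (`roberta-base` and the utterance strings are placeholders, not CoMPM's actual data pipeline): let the tokenizer pad the batch and pass its attention_mask through, so pad positions are masked out instead of defaulting to 1.

```python
from transformers import RobertaTokenizer, RobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")

# Placeholder utterances; in practice these would be the concatenated
# conversation contexts built by the data loader.
batch = ["u1 u2 u3", "u1 u2 u3 u4"]
enc = tokenizer(batch, padding=True, return_tensors="pt")

# Passing the tokenizer-built mask sets the pad positions to 0,
# so they are ignored in attention.
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.last_hidden_state.shape)  # (2, padded_seq_len, hidden_size)
```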