lingochamp / Multi-Scale-BERT-AES

Demo for the paper "On the Use of BERT for Automated Essay Scoring: Joint Learning of Multi-Scale Essay Representation"

Is the code right? #6

Open smallsmallwood opened 1 year ago

smallsmallwood commented 1 year ago

Original code:

    def forward(self, document_batch: torch.Tensor, device='cpu', bert_batch_size=0):
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        min(document_batch.shape[1],
                                            bert_batch_size),
                                        self.bert.config.hidden_size), dtype=torch.float, device=device)
        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:bert_batch_size] = self.dropout(self.bert(document_batch[doc_id][:bert_batch_size, 0],
                                                                           token_type_ids=document_batch[doc_id][:bert_batch_size, 1],
                                                                           attention_mask=document_batch[doc_id][:bert_batch_size, 2])[1])
        output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))
        output = output.permute(1, 0, 2)
        # (batch_size, seq_len, num_hiddens)
        attention_w = torch.tanh(torch.matmul(output, self.w_omega) + self.b_omega)
        attention_u = torch.matmul(attention_w, self.u_omega)  # (batch_size, seq_len, 1)
        attention_score = F.softmax(attention_u, dim=1)  # (batch_size, seq_len, 1)
        attention_hidden = output * attention_score  # (batch_size, seq_len, num_hiddens)
        attention_hidden = torch.sum(attention_hidden, dim=1)  # weighted sum: (batch_size, num_hiddens)
        prediction = self.mlp(attention_hidden)
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

Modified code:

    def forward(self, document_batch: torch.Tensor, device='cpu', bert_batch_size=0):
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        # min(document_batch.shape[1], bert_batch_size),
                                        document_batch.shape[1],
                                        self.bert.config.hidden_size), dtype=torch.float, device=device)
        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:document_batch.shape[1]] = self.dropout(self.bert(document_batch[doc_id][:document_batch.shape[1], 0],
                                                                                   token_type_ids=document_batch[doc_id][:document_batch.shape[1], 1],
                                                                                   attention_mask=document_batch[doc_id][:document_batch.shape[1], 2])[1])
        output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))
        output = output.permute(1, 0, 2)
        # (batch_size, seq_len, num_hiddens)
        attention_w = torch.tanh(torch.matmul(output, self.w_omega) + self.b_omega)
        attention_u = torch.matmul(attention_w, self.u_omega)  # (batch_size, seq_len, 1)
        attention_score = F.softmax(attention_u, dim=1)  # (batch_size, seq_len, 1)
        attention_hidden = output * attention_score  # (batch_size, seq_len, num_hiddens)
        attention_hidden = torch.sum(attention_hidden, dim=1)  # weighted sum: (batch_size, num_hiddens)
        prediction = self.mlp(attention_hidden)
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction
iamhere1 commented 1 year ago

If the longest essay in every batch has the same length, both versions are equivalent. The value of `bert_batch_size` is set from the length of the essay prompts, which may be more stable, while `document_batch.shape[1]` depends on the length of the longest essay in each batch.
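
A minimal shape check (hypothetical sizes, using only the two allocation lines from the versions above) illustrates this: with a fixed `bert_batch_size` the LSTM always sees at most that many segments, while `document_batch.shape[1]` follows the longest essay in each batch.

    import torch

    hidden_size = 768        # hypothetical BERT hidden size
    bert_batch_size = 5      # fixed segment count, chosen from the prompt length

    for max_segments in (3, 5, 8):  # longest essay in three hypothetical batches
        document_batch = torch.zeros(2, max_segments, 3, 128, dtype=torch.long)

        # original: capped at bert_batch_size, so the LSTM input length stays stable across batches
        fixed = torch.zeros(document_batch.shape[0],
                            min(document_batch.shape[1], bert_batch_size),
                            hidden_size)

        # modified: follows the longest essay, so the LSTM input length varies per batch
        dynamic = torch.zeros(document_batch.shape[0],
                              document_batch.shape[1],
                              hidden_size)

        print(max_segments, fixed.shape[1], dynamic.shape[1])

The two allocations match as long as the longest essay does not exceed `bert_batch_size` (or every batch has the same maximum length); they diverge once a batch's longest essay is longer than the fixed segment count.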