shenweichen / DeepCTR-Torch

【PyTorch】Easy-to-use, Modular and Extendible package of deep-learning based CTR models.
https://deepctr-torch.readthedocs.io/en/latest/index.html
Apache License 2.0
3.04k stars, 707 forks

I had this error "RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor" when I try to add an LSTM on top of BERT #275

Status: open. Nada-1234 opened this issue 1 year ago.

Nada-1234 commented 1 year ago

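import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn.utils.rnn import pack_padded_sequence
from transformers import AutoModel


class SentencePairClassifier(nn.Module):

    def __init__(self, bert_model="bert-base-multilingual-cased", freeze_bert=False):
        super(SentencePairClassifier, self).__init__()
        # Instantiate the BERT-based model
        self.bert_layer = AutoModel.from_pretrained(bert_model)
        self.hidden_size = 768

        self.LSTM = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)

        # Freeze BERT layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer
        self.cls_layer = nn.Linear(self.hidden_size * 2, 1)

        # self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):

        # Feed the inputs to the BERT-based model to obtain contextualized representations.
        # return_dict=False makes the model return a plain tuple, so the unpacking below
        # also works on transformers versions that return a ModelOutput by default.
        cont_reps, pooler_output = self.bert_layer(input_ids, attn_masks, return_dict=False)
        # (batch, seq_len, hidden) -> (seq_len, batch, hidden), the LSTM's default layout
        cont_reps = cont_reps.permute(1, 0, 2)

        # This call raises the RuntimeError in the title: token_type_ids is a 2-D
        # (batch, seq_len) tensor, usually on the GPU, whereas pack_padded_sequence
        # expects a 1-D CPU int64 tensor holding one sequence length per example.
        enc_hiddens, (last_hidden, last_cell) = self.LSTM(pack_padded_sequence(cont_reps, token_type_ids))

        # Concatenate the final forward and backward hidden states
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)

        # training=self.training disables dropout in eval mode
        output_hidden = F.dropout(output_hidden, 0.2, training=self.training)
        output = self.cls_layer(output_hidden)
        return torch.sigmoid(output)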

I have tried many things, but nothing works. Please help.
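The error comes from the `pack_padded_sequence` call. Its second argument must hold one length per sequence, as a 1-D int64 tensor on the CPU, while `token_type_ids` is a 2-D `(batch, seq_len)` tensor that typically lives on the GPU. A minimal sketch of a fix, assuming right-padded inputs whose attention mask has 1 for real tokens and 0 for padding, derives the lengths from `attn_masks` instead. `lengths_from_mask` and `BiLSTMHead` are hypothetical names used for illustration, not part of DeepCTR-Torch or transformers. The head uses `batch_first=True` so BERT's `(batch, seq_len, hidden)` output can be fed without a permute, and `enforce_sorted=False` so the batch need not be sorted by length.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


def lengths_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a (batch, seq_len) 0/1 mask into per-example lengths."""
    # Summing the mask over the sequence dimension counts the real tokens in each
    # example (assumes right padding). pack_padded_sequence requires the result
    # as a 1-D int64 tensor on the CPU, whatever device the mask lives on.
    return attention_mask.sum(dim=1).to(device="cpu", dtype=torch.int64)


class BiLSTMHead(nn.Module):
    """Hypothetical BiLSTM + linear classifier over encoder hidden states."""

    def __init__(self, hidden_size: int = 768, dropout: float = 0.2):
        super().__init__()
        # batch_first=True matches the (batch, seq_len, hidden) layout of BERT outputs.
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        # nn.Dropout is switched off automatically by model.eval().
        self.dropout = nn.Dropout(dropout)
        self.cls_layer = nn.Linear(hidden_size * 2, 1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        lengths = lengths_from_mask(attention_mask)
        # enforce_sorted=False lets the batch arrive in any length order; the LSTM
        # returns the final states in the original batch order.
        packed = pack_padded_sequence(
            hidden_states, lengths, batch_first=True, enforce_sorted=False
        )
        _, (last_hidden, _) = self.lstm(packed)
        # last_hidden has shape (num_layers * 2, batch, hidden); the last two rows are
        # the final forward and backward states of the top layer.
        features = torch.cat((last_hidden[-2], last_hidden[-1]), dim=1)
        logits = self.cls_layer(self.dropout(features))
        return torch.sigmoid(logits)
```

Inside the original model this would be called roughly as `head(self.bert_layer(input_ids=input_ids, attention_mask=attn_masks).last_hidden_state, attn_masks)`. Because the sequences are packed, padded positions never reach the LSTM, so the final hidden states depend only on the real tokens.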

zanshuxun commented 1 year ago

Please add more information so that we can reproduce the error.

Describe the bug: A clear and concise description of what the bug is.

To Reproduce: Steps to reproduce the behavior:

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

Operating environment:

Additional context: Add any other context about the problem here.
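For the operating-environment part of the template, a small sketch like the following could collect the versions that usually matter for this kind of error (Python, PyTorch and its CUDA build, transformers). `environment_report` is a hypothetical helper, and the default package names are assumptions about what is installed; missing packages are simply reported as "not installed".

```python
import platform
from importlib import metadata


def environment_report(packages=("torch", "transformers", "deepctr-torch")):
    """Hypothetical helper: gather version info to paste into an issue."""
    report = {"python": platform.python_version(), "os": platform.platform()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    # CUDA details are only available if PyTorch itself can be imported.
    try:
        import torch

        report["cuda (torch build)"] = torch.version.cuda or "cpu-only"
        report["cuda available"] = str(torch.cuda.is_available())
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```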