Closed Claymore715 closed 2 years ago
def bert_out():
    """Run the masked-LM model on the prepared inputs and read off loss/logits.

    NOTE(review): this snippet references names from an enclosing scope
    (`self`, `queries`, `bz`, `label_ids`, `inputs_embeds`, `attention_mask`);
    it is a fragment and is not runnable on its own.
    """
    # Position of the [MASK] token in each sequence.
    # assumes exactly one mask token per row of `queries` — TODO confirm
    mask_positions = (queries == self.tokenizer.mask_token_id).nonzero()
    label_mask = mask_positions.reshape(bz, -1)[:, 1].unsqueeze(1).to(self.device)  # bz * 1

    # Labels start as the ignore index (-100) everywhere, then the true
    # label id is scattered into the masked position of each row.
    labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
    labels = labels.scatter_(1, label_mask, label_ids)

    # Forward pass; passing `labels` makes the model compute the MLM loss.
    output = self.model(
        inputs_embeds=inputs_embeds.to(self.device),
        attention_mask=attention_mask.to(self.device).bool(),
        labels=labels.to(self.device),
    )
    loss, logits = output.loss, output.logits
@renyuanzhe Hi,
You seem to be using a different version of huggingface transformers from ours. Please try using the version provided in the requirements.txt.