THUDM / P-tuning

A novel method to tune language models. Code and datasets for the paper "GPT understands, too".
MIT License

BERT output is a tuple (in LAMA) #24

Closed: olenmg closed this issue 3 years ago.

olenmg commented 3 years ago

Hi, thanks for the great codebase!

In the bert_out method of the PTuneForLAMA class:

# LAMA/p_tuning/modeling.py (around line 124)
def bert_out():
    label_mask = (queries == self.tokenizer.mask_token_id).nonzero().reshape(bz, -1)[:, 1].unsqueeze(
        1).to(self.device)  # bz * 1
    labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
    labels = labels.scatter_(1, label_mask, label_ids)
    output = self.model(inputs_embeds=inputs_embeds.to(self.device),
                        attention_mask=attention_mask.to(self.device).bool(),
                        labels=labels.to(self.device))
    loss, logits = output.loss, output.logits  # fails on Transformers 3.0.2: output is a tuple

The output object has no loss or logits attributes because it is a tuple.

I think it should be changed as follows:

def bert_out():
    label_mask = (queries == self.tokenizer.mask_token_id).nonzero().reshape(bz, -1)[:, 1].unsqueeze(
        1).to(self.device)  # bz * 1
    labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
    labels = labels.scatter_(1, label_mask, label_ids)
    # On Transformers 3.0.2 the model returns a plain tuple that unpacks as (loss, logits)
    loss, logits = self.model(inputs_embeds=inputs_embeds.to(self.device),
                              attention_mask=attention_mask.to(self.device).bool(),
                              labels=labels.to(self.device))

I checked that this code works on my machine. Thank you again.


Update (07.08): gpt_out() has the same issue:

# The GPT-2 tuple has a third element (the cached key/values), which is discarded here
loss, logits, _ = self.model(inputs_embeds=inputs_embeds.to(self.device).half(),
                             attention_mask=attention_mask.to(self.device).half(),
                             labels=labels.to(self.device))

With a newer Hugging Face Transformers version that accepts the return_dict argument, this can also be solved by passing return_dict=True, so that the output exposes the loss and logits attributes.
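Another way to cover both cases, sketched here as an assumption rather than the repository's actual fix, is to accept either output style in one place. The helper name unpack_loss_logits is hypothetical, and it assumes the model was called with labels, so a plain tuple starts with the loss followed by the logits. Any further elements, such as the third one returned in gpt_out(), are ignored.

```python
from types import SimpleNamespace
from typing import Any, Tuple


def unpack_loss_logits(output: Any) -> Tuple[Any, Any]:
    """Return (loss, logits) from a model call that was given labels.

    Hypothetical helper: handles both plain tuples (as returned by
    Transformers 3.0.2) and attribute-style outputs (return_dict=True).
    """
    # Attribute-style outputs expose named fields directly.
    if hasattr(output, "loss") and hasattr(output, "logits"):
        return output.loss, output.logits
    # Plain tuples: with labels passed, loss comes first and logits second.
    # Extra trailing elements (e.g. GPT-2's cached key/values) are ignored.
    return output[0], output[1]


# Stand-in values only; real calls would pass tensors.
print(unpack_loss_logits((0.5, "bert_logits")))               # (0.5, 'bert_logits')
print(unpack_loss_logits((0.7, "gpt_logits", "presents")))    # (0.7, 'gpt_logits')
print(unpack_loss_logits(SimpleNamespace(loss=0.9, logits="dict_logits")))  # (0.9, 'dict_logits')
```

Inside bert_out() or gpt_out(), the call would then read loss, logits = unpack_loss_logits(self.model(...)), which keeps working whether the installed Transformers returns a tuple or an object with named fields.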

Xiao9905 commented 3 years ago

Thanks for your response. I think it is a package-version problem with Hugging Face Transformers. The code should work fine with versions after 4.3.0.
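To catch an older install (such as a 3.x release) before the tuple error appears, a small version check could run at startup. This is a hypothetical sketch: meets_minimum and release_tuple are illustrative names, 4.3.0 is taken from the comment above and treated as an inclusive lower bound (an assumption), and pre-release suffixes such as .dev0 or rc1 are simply ignored.

```python
import re
from typing import Tuple

# Leading numeric release segment, e.g. "4.3.0" in "4.3.0.dev0".
_RELEASE = re.compile(r"\d+(?:\.\d+)*")


def release_tuple(version: str) -> Tuple[int, ...]:
    """Parse the numeric release part; pre-release/dev suffixes are dropped."""
    match = _RELEASE.match(version.strip())
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(installed: str, minimum: str = "4.3.0") -> bool:
    """True if `installed` is at least `minimum` (inclusive, assumed bound)."""
    a, b = release_tuple(installed), release_tuple(minimum)
    # Pad with zeros so "4.3" compares equal to "4.3.0".
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print(meets_minimum("3.0.2"))      # False
print(meets_minimum("4.3.0"))      # True
print(meets_minimum("4.10.0rc1"))  # True
```

In a real run, the installed version string could come from transformers.__version__ or importlib.metadata.version("transformers"). If the packaging library is available, packaging.version.Version gives a more complete comparison, including pre-releases.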

Xiao9905 commented 3 years ago

Thanks again for your careful checking! I think this may help other users run the code.

olenmg commented 3 years ago

I raised this issue because requirements.txt lists Transformers version 3.0.2. Thanks for answering.