Georgetown-IR-Lab / cedr

Code for CEDR: Contextualized Embeddings for Document Ranking, accepted at SIGIR 2019.
MIT License

transformers rather than pytorch_pretrained_bert #28

Open cmacdonald opened 4 years ago

cmacdonald commented 4 years ago

We'd like to make use of the more generic transformers library. There is some migration information at

We're trying to upgrade a BertRanker:

class VanillaBertTransformerRanker(BertRanker):
    def __init__(self):

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.bert = AutoModel.from_pretrained('bert-base-uncased')

        self.dropout = torch.nn.Dropout(0.1)
        self.cls = torch.nn.Linear(self.BERT_SIZE, 1)

        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):

            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)

We needed to do some casting in, e.g.:


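However, the shapes go wrong in encode_bert: on the second pass through the loop, each cls_result entry has size [4] before the stack rather than [4, 768], and the stack then fails. The two traces below compare the original ranker with the transformers version.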

# vanilla_bert
# shape of first tensor: torch.Size([283, 768])
# shape of first result: torch.Size([8, 283, 768])
# shape of first cls_result: torch.Size([4, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# shape of first cls_result after stack: torch.Size([768])
# shape of cls_result after stack: 4
# shape of first cls_result: torch.Size([4, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# ...

# transformer bert
# shape of first tensor: torch.Size([283, 768])
# shape of first result: torch.Size([8, 283, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# shape of first cls_result after stack: torch.Size([768])
# shape of cls_result after stack: 4
# shape of first cls_result before stack: torch.Size([4])
# shape of cls_result before stack: 2
# ---> error
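
One possible reading of the trace above, offered as a sketch rather than a confirmed diagnosis: by default a transformers BERT model returns the last hidden state followed by the pooled output, not a list of per-layer tensors. If that is what `result` holds here, then `for layer in result` sees a [8, 283, 768] tensor first and a [8, 768] tensor second, which would produce exactly the [4, 768] then [4] pattern. The snippet below reproduces only the shapes with random tensors; the slicing inside the inner loop is an assumption, since that line is not shown in the post.

```python
import torch

BATCH, SUBBATCHES, SEQ_LEN, HIDDEN = 4, 2, 283, 768

# Stand-ins for a default transformers BERT output: (last_hidden_state, pooler_output).
# Random values; only the shapes matter for this illustration.
last_hidden_state = torch.randn(BATCH * SUBBATCHES, SEQ_LEN, HIDDEN)
pooler_output = torch.randn(BATCH * SUBBATCHES, HIDDEN)
result = (last_hidden_state, pooler_output)


def split_cls(layer, batch):
    """Take position 0 of `layer` and cut it into sub-batches of size `batch`.

    The slicing is a guess at the elided loop body from the post.
    """
    cls_output = layer[:, 0]
    return [cls_output[i * batch:(i + 1) * batch]
            for i in range(cls_output.shape[0] // batch)]


# Shapes of the chunks for each element that the loop iterates over.
shapes = [[tuple(c.shape) for c in split_cls(layer, BATCH)] for layer in result]

# First element: stacking on dim=2 and averaging works as in the original ranker.
ok = torch.stack(split_cls(result[0], BATCH), dim=2).mean(dim=2)

# Second element: the chunks are 1-D, so dim=2 does not exist for torch.stack.
try:
    torch.stack(split_cls(result[1], BATCH), dim=2)
    stack_failed = False
except (IndexError, RuntimeError):
    stack_failed = True

print(shapes[0], shapes[1], tuple(ok.shape), stack_failed)
```

If this is the cause, one option would be to load the model with `output_hidden_states=True` and loop over the `hidden_states` field of the output, which holds one [batch, seq_len, hidden] tensor for the embeddings and for each layer. Whether that matches what the original ranker consumed would need checking against the CEDR code.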


petulla commented 4 years ago

@cmacdonald Did you migrate to HuggingFace? I'm interested in trying your pretrained model once it's released.