We'd like to make use of the more generic transformers library. There is some migration information at https://huggingface.co/transformers/migration.html.
We're trying to upgrade a BertRanker:

import torch
from transformers import AutoModel, AutoTokenizer

class VanillaBertTransformerRanker(BertRanker):  # BertRanker is our existing base class
    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(0.1)
        self.cls = torch.nn.Linear(self.BERT_SIZE, 1)  # BERT_SIZE (768 here) comes from the base ranker
<snip>
        # build the CLS representation: for each layer, average the [CLS] vectors across subbatches
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i*BATCH:(i+1)*BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)
We needed to do some casting in train.py, e.g.:
torch.tensor(record['query_tok']).to(torch.int64)
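The cast is applied where train.py turns each record into model inputs, roughly like this (query_tok is from the line above; record contents and the query_mask field are hypothetical here):

import torch

# Sketch of the casting in train.py; the dummy record and 'query_mask' are illustrative only
record = {'query_tok': [101, 2054, 2003, 102], 'query_mask': [1, 1, 1, 1]}
query_tok = torch.tensor(record['query_tok']).to(torch.int64)    # token ids must be int64 for embedding lookup
query_mask = torch.tensor(record['query_mask']).to(torch.int64)  # hypothetical companion field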
However, the shapes get a bit out of kilter in encode_bert: before the stack, the chunks appended to cls_result end up with size [4] rather than [4, 768]. Debug traces for the original vanilla_bert and our transformers version:
# vanilla_bert
# shape of first tensor: torch.Size([283, 768])
# shape of first result: torch.Size([8, 283, 768])
# shape of first cls_result: torch.Size([4, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# shape of first cls_result after stack: torch.Size([768])
# shape of cls_result after stack: 4
# shape of first cls_result: torch.Size([4, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# ...
# transformer bert
# shape of first tensor: torch.Size([283, 768])
# shape of first result: torch.Size([8, 283, 768])
# shape of first cls_result before stack: torch.Size([4, 768])
# shape of cls_result before stack: 2
# shape of first cls_result after stack: torch.Size([768])
# shape of cls_result after stack: 4
# shape of first cls_result before stack: torch.Size([4])
# shape of cls_result before stack: 2
# ---> error
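Our current guess (not yet verified) is that the second element of the tuple returned by transformers is the pooler output of shape [8, 768], so its [:, 0] slice is the 1-D [8] tensor that gets chunked into the [4]s above. Since the migration guide says models now always output tuples, getting per-layer tensors like the old output_all_encoded_layers behaviour may need something like the following (a sketch based on the transformers docs, not on our code):

import torch
from transformers import AutoModel

# Sketch: request all hidden states so that `result` can again be a sequence of
# [batch, seq_len, 768] tensors (assumption: this is the missing piece in our port).
model = AutoModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
toks = torch.randint(0, 30522, (8, 283), dtype=torch.int64)
mask = torch.ones(8, 283, dtype=torch.int64)

outputs = model(toks, attention_mask=mask)
# outputs[0]: last_hidden_state [8, 283, 768]
# outputs[1]: pooler_output [8, 768]  <- the likely source of the stray 1-D slice
hidden_states = outputs[-1]  # embeddings output plus one tensor per layer
print(len(hidden_states), hidden_states[-1].shape)  # 13 torch.Size([8, 283, 768])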
+@albertoueda