Closed: nandinib1999 closed this issue 4 years ago
Could you post a complete (runnable) example script, including what type of data arr contains?
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SentenceEvaluator, SimilarityFunction
from typing import List
from sentence_transformers.readers import InputExample
import gzip
import csv
import torch
import os
from sentence_transformers import util
class Evaluator(SentenceEvaluator):
    def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], device=None, batch_size: int = 16, name: str = '', show_progress_bar: bool = False):
        self.sentences1 = sentences1
        self.sentences2 = sentences2
        self.scores = scores
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

    @classmethod
    def from_input_examples(cls, examples: List[InputExample], **kwargs):
        sentences1 = []
        sentences2 = []
        scores = []
        for example in examples:
            sentences1.append(example.texts[0])
            sentences2.append(example.texts[1])
            scores.append(example.label)
        return cls(sentences1, sentences2, scores, **kwargs)

    def __call__(self, model):
        assert len(self.sentences1) == len(self.sentences2), "Different number of Candidate and Reference Sentences"
        sentences = dedup_and_sort(self.sentences1 + self.sentences2)
        for batch_start in range(0, len(sentences), self.batch_size):
            sen_batch = sentences[batch_start : batch_start + self.batch_size]
            embs = get_bert_embeddings(sen_batch, model, self.batch_size)

def get_bert_embeddings(sentences, model, batch_size):
    # Pre-tokenize every sentence; arr is a list of tokenized sentences.
    arr = [model.tokenize(sent) for sent in sentences]
    arr = [list(a) for a in arr]
    embeddings = []
    for i in range(0, len(sentences), batch_size):
        print(arr[i : i + batch_size])
        # This call raises the error discussed below.
        b_emb = model.encode(arr[i : i + batch_size], output_value='token_embeddings', convert_to_tensor=True, is_pretokenized=True, show_progress_bar=False)
        embeddings.append(b_emb)
    total_embedding = torch.cat(embeddings, dim=0)
    return total_embedding

def dedup_and_sort(l):
    # Unique sentences, longest (by word count) first.
    return sorted(list(set(l)), key=lambda x: len(x.split(" ")), reverse=True)
if not os.path.exists('datasets'):
    os.mkdir('datasets')

sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
if not os.path.exists(sts_dataset_path):
    util.http_get('https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

dev_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] == 'dev':
            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
            dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))
model = SentenceTransformer('bert-base-nli-mean-tokens')
evaluator = Evaluator.from_input_examples(dev_samples, batch_size=64, show_progress_bar=True)
evaluator(model)
Using this script, you can reproduce the error.
arr is a list of tokenized sentences.
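For illustration (this snippet is not from the original script, and the exact return type of model.tokenize depends on the sentence-transformers version), each entry of arr is assumed to look like this:

example = model.tokenize("A man is playing a guitar.")  # hypothetical sentence
print(type(example), example[:10])  # assumed: a sequence of integer token ids for one sentence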
@nreimers - Hey, any help here will be much appreciated.
@nreimers - I am working against a deadline, so it would be really helpful if you could tell me about any progress on the above issue. Thanks.
@nandinib1999 You can fix the error by changing the line below
b_emb = model.encode(arr[i : i + batch_size], output_value='token_embeddings', convert_to_tensor=True, is_pretokenized=True, show_progress_bar=False)
with
b_emb = model.encode(arr[i: i + batch_size], output_value='token_embeddings', convert_to_tensor=True, is_pretokenized=True, show_progress_bar=False, batch_size=batch_size)
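Passing batch_size=batch_size presumably makes encode() treat each arr[i : i + batch_size] slice as a single internal batch, so every token embedding in the returned tensor shares one padded length; with the default internal batch size, the slice appears to be split into sub-batches whose padded lengths differ, and concatenating those fails. If token-embedding tensors with different padded lengths ever do need to be concatenated across separate encode() calls, a minimal sketch (assuming each tensor has shape (n_sentences, seq_len, hidden); cat_padded is a hypothetical helper, not part of the library) would be:

import torch
import torch.nn.functional as F

def cat_padded(batches):
    # Pad each token-embedding tensor along its sequence dimension to the
    # longest length in the list, then concatenate along the sentence dimension.
    max_len = max(b.shape[1] for b in batches)
    padded = [F.pad(b, (0, 0, 0, max_len - b.shape[1])) for b in batches]
    return torch.cat(padded, dim=0)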
Yes, it is working now. Thank you! @jicksonp
Hi,
I was using the 'token_embeddings' feature of sentence-transformers and ran into this error. I have passed a list of sentences to model.encode() to generate token_embeddings. Are the sentences not being padded before being converted into token_embeddings? Can you suggest a workaround for this?
b_emb = model.encode(arr[i : i + batch_size], output_value='token_embeddings', convert_to_tensor=True, is_pretokenized=True, show_progress_bar=False)
This is the exact code I am using.
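For context, a minimal sketch (hypothetical sentences; it is assumed that encode() returns one token-embedding array per sentence when convert_to_tensor is not set) showing how the per-sentence shapes differ:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('bert-base-nli-mean-tokens')
sentences = ["A short sentence.", "A noticeably longer sentence with many more tokens in it."]

# Assumed behaviour: one token-embedding array per sentence, so the differing
# sequence lengths can be inspected directly.
token_embs = model.encode(sentences, output_value='token_embeddings', show_progress_bar=False)
for emb in token_embs:
    print(emb.shape)  # (num_tokens_i, hidden_size); lengths differ per sentence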
Thanks