Thanks, that's a good tip - I've added it to the "usage" section.
I'm sharing my own 'rolling' sbert script to avoid clipping the sentences. It seems functional but is not very elegant; a class would be better of course, but I hope it helps someone:
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")


# sbert silently crops any token beyond max_seq_length,
# so we embed a rolling window over each long sentence, then maxpool.
# The normalization happens afterwards.
def encode_sentences(sentences):
    max_len = model.get_max_seq_length()

    if not isinstance(max_len, int):
        # the clip model has a different way to use the encoder
        # source: https://github.com/UKPLab/sentence-transformers/issues/1269
        assert "clip" in str(model).lower(), (
            f"sbert model with no 'max_seq_length' attribute and not clip: '{model}'")
        max_len = 77
        encode = model._first_module().processor.tokenizer.encode
    else:
        if hasattr(model.tokenizer, "encode"):
            # most models
            encode = model.tokenizer.encode
        else:
            # word embedding models like GloVe
            encode = model.tokenizer.tokenize
    assert isinstance(max_len, int), "max_len must be int"

    n23 = (max_len * 2) // 3
    add_sent = []  # additional sentences
    add_sent_idx = []  # indices to keep track of the sub sentences

    for i, s in enumerate(sentences):
        # skip if the sentence is short enough
        length = len(encode(s))
        if length <= max_len:
            continue

        # otherwise, split the sentence at regular intervals,
        # embed each sub sentence,
        # then maxpool those sub embeddings together
        # (the renormalization happens later in the code)
        sub_sentences = []
        words = s.split(" ")
        avg_tkn = length / len(words)
        j = int(max_len / avg_tkn * 0.8)  # start at 80% of the supposed max_len

        while len(encode(" ".join(words))) > max_len:
            until_j = len(encode(" ".join(words[:j])))
            if until_j >= max_len:
                # reached max length: use that minus a few words
                jjj = 1
                while len(encode(" ".join(words[:j - jjj]))) >= max_len:
                    jjj += 1
                sub_sentences.append(" ".join(words[:j - jjj]))

                # drop words from the start until about 1/3 of the max tokens
                # are removed: this way we have a rolling window
                jj = int((max_len // 3) / avg_tkn * 0.8)
                while len(encode(" ".join(words[jj:j - jjj]))) > n23:
                    jj += 1
                words = words[jj:]
                j = int(max_len / avg_tkn * 0.8)
            else:
                diff = abs(max_len - until_j)
                if diff > 10:
                    j += max(1, int(10 / avg_tkn))
                else:
                    j += 1
        sub_sentences.append(" ".join(words))

        # discard the original sentence as we will only keep
        # the maxpooled sub sentences
        sentences[i] = " "

        # remove empty text just in case
        while "" in sub_sentences:
            sub_sentences.remove("")
        assert not any(len(encode(ss)) > max_len for ss in sub_sentences), (
            f"error when splitting long sentences: {sub_sentences}")
        add_sent.extend(sub_sentences)
        add_sent_idx.extend([i] * len(sub_sentences))

    if add_sent:
        sent_check = [len(encode(s)) > max_len for s in sentences]
        addsent_check = [len(encode(s)) > max_len for s in add_sent]
        assert sum(sent_check + addsent_check) == 0, (
            f"The rolling window apparently failed:\n{sent_check}\n{addsent_check}")

    vectors = model.encode(
        sentences=sentences + add_sent,
        show_progress_bar=True,
        output_value="sentence_embedding",
        convert_to_numpy=True,
        normalize_embeddings=False,
    )

    if add_sent:
        # at the position of each original (unsplit) sentence,
        # store the maxpooled vector of its sub sentences,
        # then return only that original section
        assert len(add_sent) == len(add_sent_idx), "Invalid add_sent length"
        offset = len(sentences)
        for sid in set(add_sent_idx):
            id_range = [i for i, j in enumerate(add_sent_idx) if j == sid]
            add_sent_vec = vectors[offset + min(id_range):offset + max(id_range) + 1, :]
            vectors[sid] = np.amax(add_sent_vec, axis=0)
        return vectors[:offset]
    else:
        return vectors
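For instance, a quick usage sketch (the sample texts are made up just to show the behaviour; note that the function replaces long entries of the input list with " " in place):

# quick usage sketch with made-up texts
texts = [
    "A short sentence that fits within the limit.",
    "word " * 2000,  # far beyond the model's max_seq_length
]
vecs = encode_sentences(texts)
print(vecs.shape)  # one 768-dim vector per original text for all-mpnet-base-v2

# renormalize afterwards if you need unit vectors for cosine similarity
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)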
edit: fixed the code :/
@thiswillbeyourgithub
I like the idea of that "rolling window".
Is it the way to do it, or is it just an alternative to clipping, i.e. how well does it work?
Keep in mind that my implementation is pretty naive and can certainly be vastly optimized, but the idea is there. An implementation for langchain can be found here. A non-langchain implementation in code I'm using regularly can be found here.
I don't know if I stumbled by chance on the right way to do it, but probably not. In my example I did maxpooling, but I could have done meanpooling instead. That also raises the question of L1 vs L2 if doing a normalization. One could also think about applying an exponential decay to the importance of each new token of text, etc.
In my experience, maxpooling and meanpooling both seem to work fine. In my opinion, any kind of rolling window is in theory vastly superior to silently cropping, and the difference between windowing methods is probably negligible compared to cropping.
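To make those options concrete, here is a small hypothetical sketch (the sub_vecs array and the decay factor are made up for illustration):

import numpy as np

# hypothetical sub-sentence embeddings of one long input, shape (n_windows, dim)
sub_vecs = np.random.rand(5, 768)

pooled_max = np.amax(sub_vecs, axis=0)   # maxpooling, as in the script above
pooled_mean = sub_vecs.mean(axis=0)      # meanpooling alternative

# optional: exponential decay so later windows weigh less
decay = 0.8 ** np.arange(len(sub_vecs))
pooled_decay = (sub_vecs * decay[:, None]).sum(axis=0) / decay.sum()

# renormalization afterwards, L2 or L1
pooled_l2 = pooled_max / np.linalg.norm(pooled_max)   # L2
pooled_l1 = pooled_max / np.abs(pooled_max).sum()     # L1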
More tests with proper metrics would be needed to find out which is best, and whether the enhancement is real rather than a placebo that degrades results.
Hi,
Time and time again I encounter people disappointed by the effectiveness of the sentence-transformers models. The reason is usually that the models have a very short "max sequence length" (256 for the default model) and everything after that is silently clipped.
Given that this happens silently, I think most people are not aware of it. And the multilingual models have an even shorter length!
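For example, you can see the limit for yourself with a quick sketch (using all-MiniLM-L6-v2, whose limit is 256 tokens):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
print(model.max_seq_length)  # 256

long_text = "word " * 1000
print(len(model.tokenizer.encode(long_text)))  # ~1000 tokens: everything
# past the first 256 is silently dropped when embedding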
I brought this up several times, here on langchain and there too.
So I think it would be good to mention this in the README.md.
And if anyone is up for writing a simple wrapper that does a rolling average/maxpooling/whateverpooling of the input instead of clipping it, that would be awesome! That would be a workaround that can't possibly be worse than just clipping the input, right?
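Something along these lines could serve as a starting point (a rough sketch, not a drop-in for the plugin: it assumes a model with a HuggingFace-style tokenizer, chunks overlapping token windows, and uses mean pooling; any other pooling could be swapped in):

import numpy as np
from sentence_transformers import SentenceTransformer


def embed_long(model: SentenceTransformer, text: str, overlap: int = 50) -> np.ndarray:
    """Embed arbitrarily long text by mean-pooling overlapping windows."""
    window = model.max_seq_length - 2  # leave room for the special tokens
    tokens = model.tokenizer.encode(text, add_special_tokens=False)
    step = window - overlap
    chunks = [
        model.tokenizer.decode(tokens[i:i + window])
        for i in range(0, max(len(tokens) - overlap, 1), step)
    ]
    vecs = model.encode(chunks, convert_to_numpy=True)
    pooled = vecs.mean(axis=0)
    return pooled / np.linalg.norm(pooled)

Max pooling instead of the mean, as in the script earlier in the thread, would be a one-line change.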
Cheers and llm is great!
(related to https://github.com/simonw/llm/issues/220)