simonw / llm-sentence-transformers

LLM plugin for embeddings using sentence-transformers
Apache License 2.0

docs should mention limitation of sbert #8

Closed: thiswillbeyourgithub closed this issue 1 year ago

thiswillbeyourgithub commented 1 year ago

Hi,

I keep running into people who are disappointed by the effectiveness of sentence-transformers models. Usually the reason is that the models have a very short "max sequence length" (256 tokens for the default model), and everything beyond it is silently clipped.

Given that this happens silently, I think most people are not aware of it. And the multilingual models have even shorter limits!
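To see whether a given text is being clipped, one can compare its token count with the model's limit. A minimal sketch, assuming an `encode` callable that returns the token ids the model will actually see (special tokens included). For a `SentenceTransformer`, `model.tokenizer.encode` and `model.max_seq_length` should fill these roles, but the helper `truncation_report` itself is hypothetical:

```python
from typing import Callable, Sequence


def truncation_report(
    texts: Sequence[str],
    encode: Callable[[str], Sequence[int]],
    max_len: int,
) -> list[tuple[int, int]]:
    """Return (token_count, tokens_dropped) for each text.

    `encode` is assumed to return the token ids the model will see,
    including any special tokens, so the count is comparable to `max_len`.
    """
    report = []
    for text in texts:
        n_tokens = len(encode(text))
        # everything beyond max_len is cut off without any warning
        report.append((n_tokens, max(0, n_tokens - max_len)))
    return report


if __name__ == "__main__":
    # toy whitespace "tokenizer" standing in for a real one (assumption)
    def toy_encode(text: str) -> list[int]:
        return list(range(len(text.split())))

    print(truncation_report(["short text", "word " * 300], toy_encode, max_len=256))
    # prints [(2, 0), (300, 44)]
```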

I've brought this up several times: here on langchain, and there too.

So I think it would be good to mention this in the README.md.

And if anyone is up for writing a simple wrapper that does a rolling average/max pooling/whatever-pooling of the input instead of clipping it, that would be awesome! Such a workaround can't possibly be worse than just clipping the input, right?
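A minimal sketch of such a wrapper, working on token ids: split the ids into overlapping windows, embed each window, then pool. `embed_window` is a hypothetical callable (for example, decode the window back to text and pass it to `model.encode`); with a real tokenizer, `max_len` should leave room for special tokens such as [CLS]/[SEP].

```python
from typing import Callable, Sequence

import numpy as np


def windows(token_ids: Sequence[int], max_len: int, stride: int) -> list[list[int]]:
    """Split token ids into windows of at most max_len tokens, each starting
    `stride` tokens after the previous one (overlap = max_len - stride)."""
    if not 0 < stride <= max_len:
        raise ValueError("need 0 < stride <= max_len")
    ids = list(token_ids)
    if len(ids) <= max_len:
        return [ids]
    out = []
    start = 0
    while True:
        out.append(ids[start:start + max_len])
        if start + max_len >= len(ids):  # last window reaches the end
            break
        start += stride
    return out


def pooled_embedding(
    token_ids: Sequence[int],
    embed_window: Callable[[list[int]], np.ndarray],
    max_len: int,
    stride: int,
    pooling: str = "mean",
) -> np.ndarray:
    """Embed every window and pool the results instead of clipping."""
    vecs = np.stack([embed_window(w) for w in windows(token_ids, max_len, stride)])
    if pooling == "mean":
        return vecs.mean(axis=0)
    if pooling == "max":
        return vecs.max(axis=0)
    raise ValueError(f"unknown pooling: {pooling}")
```

A stride smaller than `max_len` gives overlapping windows, so text near a window boundary is seen in full by at least one window.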

Cheers and llm is great!

(related to https://github.com/simonw/llm/issues/220)

simonw commented 1 year ago

Thanks, that's a good tip - I've added it to the "usage" section.

thiswillbeyourgithub commented 1 year ago

I'm sharing my own "rolling" sbert script to avoid clipping sentences. It seems to work but isn't very elegant (a class would be better, of course); I just hope it helps someone:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# sbert silently crops any token beyond max_seq_length, so we split long
# texts into overlapping windows, embed each window, then maxpool the
# window embeddings. The vectors are returned unnormalized: normalize
# them afterwards if needed.
def encode_sentences(sentences):
    sentences = list(sentences)  # copy so the caller's list is not modified
    max_len = model.get_max_seq_length()

    if not isinstance(max_len, int):
        # the CLIP model has a different way to use the encoder
        # source: https://github.com/UKPLab/sentence-transformers/issues/1269
        assert "clip" in str(model).lower(), f"sbert model with no 'max_seq_length' attribute and not clip: '{model}'"
        max_len = 77
        encode = model._first_module().processor.tokenizer.encode
    else:
        if hasattr(model.tokenizer, "encode"):
            # most models
            encode = model.tokenizer.encode
        else:
            # word embedding models like GloVe
            encode = model.tokenizer.tokenize

    assert isinstance(max_len, int), "max_len must be an int"
    n23 = (max_len * 2) // 3
    add_sent = []  # additional sentences (the sub-sentences)
    add_sent_idx = []  # index of the original sentence of each sub-sentence

    for i, s in enumerate(sentences):
        # skip if the sentence is short
        length = len(encode(s))
        if length <= max_len:
            continue

        # otherwise, split the sentence at regular intervals,
        # embed each piece, and finally maxpool those sub-embeddings together
        sub_sentences = []
        words = s.split(" ")
        avg_tkn = length / len(words)
        j = int(max_len / avg_tkn * 0.8)  # start at 80% of the estimated word budget
        while len(encode(" ".join(words))) > max_len:

            # if max length is reached, back off word by word until it fits
            until_j = len(encode(" ".join(words[:j])))
            if until_j >= max_len:
                jjj = 1
                while len(encode(" ".join(words[:j-jjj]))) >= max_len:
                    jjj += 1
                sub_sentences.append(" ".join(words[:j-jjj]))

                # drop leading words until about 1/3 of the window is removed,
                # so that consecutive windows overlap (rolling window)
                jj = int((max_len // 3) / avg_tkn * 0.8)
                while len(encode(" ".join(words[jj:j-jjj]))) > n23:
                    jj += 1
                words = words[jj:]

                j = int(max_len / avg_tkn * 0.8)
            else:
                diff = abs(max_len - until_j)
                if diff > 10:
                    j += max(1, int(10 / avg_tkn))
                else:
                    j += 1

        sub_sentences.append(" ".join(words))

        # replace the original sentence by a placeholder: only the
        # maxpooled sub-sentences will be kept for it
        sentences[i] = " "

        # remove empty text just in case
        sub_sentences = [ss for ss in sub_sentences if ss]
        assert all(len(encode(ss)) <= max_len for ss in sub_sentences), f"error when splitting long sentences: {sub_sentences}"
        add_sent.extend(sub_sentences)
        add_sent_idx.extend([i] * len(sub_sentences))

    if add_sent:
        sent_check = [len(encode(s)) > max_len for s in sentences]
        addsent_check = [len(encode(s)) > max_len for s in add_sent]
        assert sum(sent_check + addsent_check) == 0, (
            f"The rolling window failed apparently:\n{sent_check}\n{addsent_check}")

    vectors = model.encode(
        sentences + add_sent,
        show_progress_bar=True,
        output_value="sentence_embedding",
        convert_to_numpy=True,
        normalize_embeddings=False,
    )

    if add_sent:
        # at the position of each original (split) sentence, store the
        # element-wise max of its sub-sentence vectors, then return only
        # the rows of the original sentences
        assert len(add_sent) == len(add_sent_idx), "Invalid add_sent length"
        offset = len(sentences)
        for sid in set(add_sent_idx):
            id_range = [k for k, j in enumerate(add_sent_idx) if j == sid]
            # "+ 1" so that the last sub-sentence is included in the slice
            add_sent_vec = vectors[offset + min(id_range): offset + max(id_range) + 1, :]
            vectors[sid] = np.amax(add_sent_vec, axis=0)
        return vectors[:offset]
    else:
        return vectors
```

edit: fixed the code :/
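The last step of the script, folding the sub-sentence vectors back into the row of their original sentence, can be tested on its own with plain NumPy. The helper name `fold_subvectors` is hypothetical; it uses a boolean mask so that every sub-sentence row of a sentence, including the last one, enters the max.

```python
import numpy as np


def fold_subvectors(vectors: np.ndarray, n_orig: int, sub_idx: list[int]) -> np.ndarray:
    """Max-pool sub-sentence rows back into their original sentence rows.

    Rows 0..n_orig-1 of `vectors` are the original sentences; row n_orig + k
    is a sub-sentence of original sentence sub_idx[k].
    """
    if len(sub_idx) != len(vectors) - n_orig:
        raise ValueError("sub_idx must have one entry per sub-sentence row")
    out = vectors[:n_orig].copy()
    sub = vectors[n_orig:]
    idx = np.asarray(sub_idx, dtype=int)
    for sid in np.unique(idx):
        # the mask selects all sub-sentence rows belonging to sentence `sid`
        out[sid] = sub[idx == sid].max(axis=0)
    return out
```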

Jakobhenningjensen commented 8 months ago

@thiswillbeyourgithub

I like the idea of that "rolling window".

Is it the way to do it, or just an alternative to clipping? I.e., how well does it work?

thiswillbeyourgithub commented 8 months ago

Keep in mind that my implementation is pretty naive and could certainly be optimized a lot, but the idea is there. An implementation for langchain can be found here. A non-langchain implementation, in code I use regularly, can be found here.

I don't know whether I stumbled on the way to do it; probably not. In my example I used max pooling, but I could have used mean pooling instead. That also raises the question of L1 vs L2 normalization. One could also imagine giving each new token of text an exponentially decaying importance, etc.
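The variants listed above (max vs mean pooling, L1 vs L2 normalization, decaying weights) can be compared with a small helper like this hypothetical `pool_windows`. Applying the decay per window rather than per token, and only to mean pooling, is an assumption made here for simplicity:

```python
from typing import Optional

import numpy as np


def pool_windows(
    vecs: np.ndarray,
    method: str = "max",
    norm: Optional[str] = "l2",
    decay: Optional[float] = None,
) -> np.ndarray:
    """Combine per-window embeddings (shape [n_windows, dim]) into one vector.

    method: "max" or "mean". decay (mean only, assumed 0 < decay <= 1):
    window k gets weight decay**k, so later text counts less.
    norm: "l2", "l1", or None.
    """
    if method == "max":
        pooled = vecs.max(axis=0)
    elif method == "mean":
        n = len(vecs)
        weights = np.ones(n) if decay is None else decay ** np.arange(n)
        pooled = (weights[:, None] * vecs).sum(axis=0) / weights.sum()
    else:
        raise ValueError(f"unknown method: {method}")

    if norm is None:
        return pooled
    if norm == "l2":
        scale = np.linalg.norm(pooled)
    elif norm == "l1":
        scale = np.abs(pooled).sum()
    else:
        raise ValueError(f"unknown norm: {norm}")
    # leave an all-zero vector as is instead of dividing by zero
    return pooled / scale if scale > 0 else pooled
```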

In my experience, max pooling and mean pooling both seem to work fine. In my opinion, any kind of rolling window should in theory be vastly superior to silent cropping, and the differences between windowing methods are probably negligible compared to cropping.

More tests with proper metrics would be needed to find out which method is best, and whether the improvement is real rather than a placebo that actually degrades results.
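One possible setup for such a test: embed the same documents twice (clipped vs windowed), then compare a retrieval metric on queries whose relevant document is known. A minimal recall@k sketch; the function name and the labeled query set are assumptions, not something from this thread:

```python
import numpy as np


def recall_at_k(query_vecs: np.ndarray, doc_vecs: np.ndarray, gold: list[int], k: int = 5) -> float:
    """Fraction of queries whose relevant document (index gold[i]) is among
    the k documents with the highest cosine similarity."""
    q = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    sims = q @ d.T  # cosine similarity matrix, shape [n_queries, n_docs]
    # indices of the k best-scoring documents for each query
    top_k = np.argsort(-sims, axis=1)[:, :k]
    hits = [gold[i] in top_k[i] for i in range(len(gold))]
    return float(np.mean(hits))
```

Running it once with the clipped document vectors and once with the windowed ones, on the same queries, gives a direct comparison of the two strategies.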