Anush008 / fastembed-go

Go implementation of @Qdrant/fastembed.
https://pkg.go.dev/github.com/anush008/fastembed-go
MIT License

Your library generates incorrect embeddings #12

Open 0110G opened 1 month ago

0110G commented 1 month ago

Hi. I am getting different embeddings when using:

1. Python, sentence-transformers:

```python
from sentence_transformers import SentenceTransformer

model_standard = SentenceTransformer("all-MiniLM-L6-v2")
print(model_standard.encode("Hello World"))
```

Output: `[-3.44772786e-02 3.10232081e-02 6.73494861e-03 2.61090137e-02 -3.93620506e-02 -1.60302490e-01 6.69240057e-02 -6.44144369e-03...`

2. Using your lib:

```go
package main

import (
	"fmt"

	"github.com/anush008/fastembed-go"
)

func main() {
	// With custom options
	options := fastembed.InitOptions{
		Model:     fastembed.AllMiniLML6V2,
		CacheDir:  "model_cache",
		MaxLength: 2000,
	}

	model, err := fastembed.NewFlagEmbedding(&options)
	if err != nil {
		panic(err)
	}
	defer model.Destroy()

	documents := []string{
		"Hello World",
	}

	// Generate embeddings with a batch size of 1 (defaults to 256)
	embeddings, err := model.Embed(documents, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println(embeddings)
}
```

Output: `[[0.025913946 0.0057322374 0.01147225 0.037964534 -0.023283843 -0.05493553 0.014040766 ....`

I have also tried PassageEmbed, but it didn't help. How can I make these two match?
Anush008 commented 1 month ago

It may be because of the different normalization of all-MiniLM-L6-v2. Can you try a different model, for example BGESmall, and let me know?
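
For what it's worth, trying a different model should only need a change to the `Model` option. A minimal sketch, assuming the constant for BGE-small is `fastembed.BGESmallEN` (check the library's model enum for the exact name):

```go
// Same setup as before; only the model is swapped to BGE-small.
// NOTE: fastembed.BGESmallEN is an assumed constant name; verify it
// against the model constants exported by fastembed-go.
options := fastembed.InitOptions{
	Model:    fastembed.BGESmallEN,
	CacheDir: "model_cache",
}
```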

0110G commented 1 month ago

My use case involves MiniLM only.

Anush008 commented 1 month ago

Okay. It's just to pinpoint the issue.

0110G commented 1 month ago

Most likely your implementation does not take care of mean pooling:

```python
import torch

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
```

See: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

This is something I didn't face in Qdrant's Python fastembed implementation.
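
For context, a rough Go sketch of that pooling step over the ONNX model's token embeddings. The function name and shapes here are assumptions for illustration, not fastembed-go's actual internals; note that sentence-transformers also L2-normalizes the pooled vector afterwards, so that step would need to match as well.

```go
// meanPool averages token embeddings over the sequence dimension,
// weighting by the attention mask so padding tokens are ignored.
// tokenEmbeddings has shape [seqLen][hiddenDim]; attentionMask has length seqLen.
func meanPool(tokenEmbeddings [][]float32, attentionMask []int64) []float32 {
	hiddenDim := len(tokenEmbeddings[0])
	pooled := make([]float32, hiddenDim)
	var maskSum float32

	for i, tokenVec := range tokenEmbeddings {
		if attentionMask[i] == 0 {
			continue // skip padding positions
		}
		maskSum++
		for j, v := range tokenVec {
			pooled[j] += v
		}
	}

	if maskSum < 1e-9 {
		maskSum = 1e-9 // clamp, mirroring torch.clamp(min=1e-9) above
	}
	for j := range pooled {
		pooled[j] /= maskSum
	}
	return pooled
}
```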

Anush008 commented 1 month ago

Yes. FastEmbed does this separately for this model. https://github.com/qdrant/fastembed/blob/9387ca320507737306ba6ec663398813158f8e29/fastembed/text/mini_lm_embedding.py#L50-L54

It's not updated here.

0110G commented 1 month ago

Would you be able to add this support? It would be really helpful. I noticed this issue in the Rust implementation as well.

Anush008 commented 1 month ago

I work on these projects on the weekends, so most probably then.

0110G commented 1 month ago

The output this lib returns doesn't even match the torch model output before mean pooling 😅

Anush008 commented 1 month ago

Now that's unexpected.

0110G commented 1 month ago

Will you be able to fix this?

Anush008 commented 1 month ago

I can take a look during the weekends.

0110G commented 1 month ago

Any updates?

Anush008 commented 1 month ago

I was busy at work and couldn't get to this. If you're willing, contributions are welcome.