elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)

Cross Encoder support #251

Open michalwarda opened 1 year ago

michalwarda commented 1 year ago

Hi, I'm currently trying to implement a feature called "hybrid search" inside my application. The idea is to return query results from multiple databases and then score the results from all sources together. To score them I want to use a cross-encoder model, e.g. https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2.

I'm trying to understand whether Bumblebee currently supports models like this and, if so, how to use them for that.

If it does, I'll be happy to write some documentation for it after getting a few hints. If not, supporting these kinds of operations would be a very cool feature :)
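Concretely, the re-ranking flow I have in mind looks roughly like this (just a sketch; vector_search/1, full_text_search/1, and CrossEncoder.score/1 are placeholders for my own retrieval and scoring code):

# Hypothetical hybrid-search flow: collect candidates from several sources,
# deduplicate them, then score every {query, text} pair with one cross-encoder
# so all results land on a common scale
candidates =
  Enum.uniq_by(vector_search(query) ++ full_text_search(query), & &1.id)

pairs = Enum.map(candidates, &{query, &1.text})
scores = CrossEncoder.score(pairs)

ranked =
  candidates
  |> Enum.zip(scores)
  |> Enum.sort_by(fn {_candidate, score} -> score end, :desc)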

jonatanklosko commented 1 year ago

Hey @michalwarda! That repository uses the BERT architecture, so Bumblebee supports the model; however, we don't have any serving that fits a cross-encoder. Currently you could do this:

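# The cross-encoder repo does not ship a tokenizer.json, so we borrow the
# tokenizer from the BERT model it was trained from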
{:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

inputs =
  Bumblebee.apply_tokenizer(tokenizer, [
    {"How many people live in Berlin?",
     "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."},
    {"How many people live in Berlin?",
     "New York City is famous for the Metropolitan Museum of Art."}
  ])

outputs = Axon.predict(model_info.model, model_info.params, inputs)

outputs.logits
#=> #Nx.Tensor<
#=>   f32[2][1]
#=>   EXLA.Backend<fake_gpu:0, 0.1151319922.1832779796.76254>
#=>   [
#=>     [8.845853805541992],
#=>     [-11.245560646057129]
#=>   ]
#=> >
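If you then want relevance scores rather than raw logits, one option is to squash them with a sigmoid and sort your candidates by the result (a small sketch on top of the snippet above; `passages` is assumed to be your list of candidate texts, in the same order they were tokenized):

# Nx.sigmoid/1 maps each logit into (0, 1); higher means more relevant
scores = outputs.logits |> Nx.sigmoid() |> Nx.to_flat_list()

passages
|> Enum.zip(scores)
|> Enum.sort_by(fn {_passage, score} -> score end, :desc)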

I think we can add a serving like Bumblebee.Text.cross_encoding to optimise for this use case. Ideally we would also open PRs adding tokenizer.json to the HF repositories, because in this case it's far from obvious that bert-base-uncased is the place to look.
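In the meantime, wrapping the snippet above in a hand-rolled Nx.Serving could look roughly like this (a sketch only, assuming model_info and tokenizer from above; the Nx.Serving callback signatures have changed between Nx versions, so double-check against the Nx docs for the version you use):

# Build the prediction function once; Nx.Serving reuses it for every batch
serving =
  Nx.Serving.new(fn _defn_options ->
    {_init_fn, predict_fn} = Axon.build(model_info.model)
    fn inputs -> predict_fn.(model_info.params, inputs) end
  end)
  |> Nx.Serving.client_preprocessing(fn pairs ->
    # Tokenize the {query, passage} pairs and batch them together
    inputs = Bumblebee.apply_tokenizer(tokenizer, pairs)
    {Nx.Batch.concatenate([inputs]), :ok}
  end)
  |> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, _client_info ->
    # Return one relevance score per input pair
    outputs.logits |> Nx.sigmoid() |> Nx.to_flat_list()
  end)

Nx.Serving.run(serving, [
  {"How many people live in Berlin?",
   "Berlin has a population of 3,520,031 registered inhabitants."}
])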

samrat commented 1 year ago

Thank you for the hint about bert-base-uncased :)

I am also interested in this use case. Is this a feature that will be added?

jonatanklosko commented 1 year ago

> Thank you for the hint about bert-base-uncased :)

We now have a more specific error message when a tokenizer is missing in the repository, with suggested steps to get a compatible one, so hopefully it's more intuitive now and there's no guessing repositories :)

> I am also interested in this use case. Is this a feature that will be added?

Yeah, I think Bumblebee.Text.cross_encoding makes sense. It's not the top priority right now, but contributions are also welcome.

toranb commented 1 year ago

Huge thanks to @jonatanklosko for sharing this solution! I ran into this today and wanted to share a working Nx serving I put together for my use case (RAG with hybrid search, combining Postgres pgvector and full-text search):

https://github.com/toranb/rag-n-drop/commit/4b515edbc24a5d6fbb32966ea64e9e35030f7365

It includes a working example for anyone who bumps into this later on :)