quarkiverse / quarkus-langchain4j

Quarkus Langchain4j extension
https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html
Apache License 2.0

Embed RAG data closer to source data? #675

Open FroMage opened 2 weeks ago

FroMage commented 2 weeks ago

The current RAG model for pgvector is to store the documents in their own table.

In my application, my source documents already have their own table:

@Entity
public class Talk extends PanacheEntity {
 public String description;
 public String title;
}

So, when I iterate over the talks to index them, their documents all go into a separate table:

// DO LLM
Log.infof("Loading data from talks for LLM");
List<Talk> talks = Talk.listAll();
List<Document> documents = new ArrayList<>();
store.removeAll();
Log.infof("Documents: %d", talks.size());
for (Talk talk : talks) {
  Map<String, String> metadata = new HashMap<>();
  metadata.put("title", talk.title);
  metadata.put("id", talk.id.toString());
  if (talk.description != null && !talk.description.isBlank()) {
    documents.add(new Document("Title: "+talk.title+"\nID: "+talk.id+"\nDescription: "+talk.description, Metadata.from(metadata)));
  } else {
    Log.infof("Skipping talk %s", talk.getTitle());
  }
}
Log.infof("Ingesting LLM");
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .embeddingStore(store)
                .embeddingModel(embeddingModel)
                .documentSplitter(DocumentSplitters.recursive(2000, 0))
                .build();
// Warning - this can take a long time...
ingestor.ingest(documents);
Log.infof("Ingesting LLM done");

This leads me to wonder how I can keep my model and the index in sync. What do I do when I update a single Talk entity? Do I need to re-index the entire store?

Intuitively, I was expecting to be able to do something like:

@Entity
public class Talk extends PanacheEntity {
    public String description;
    public String title;

    @JdbcTypeCode(SqlTypes.VECTOR)
    @Array(length = 3)
    public float[] myvector;

    // Hypothetical annotation: marks the method that produces the document to index
    @IndexProducer
    public Document getDocumentForIndex() {
        if (description != null && !description.isBlank()) {
            Map<String, String> metadata = new HashMap<>();
            metadata.put("title", title);
            metadata.put("id", id.toString());
            return new Document("Title: " + title + "\nID: " + id + "\nDescription: " + description, Metadata.from(metadata));
        } else {
            return null;
        }
    }

    @PreUpdate
    @PrePersist
    public void prePersist() {
        // tell langchain4j to reindex me, somehow
    }
}

But I'm not too sure how to wire this up.
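One way to wire this up, sketched in plain Java with no JPA or langchain4j dependency: a repository fires a callback after each save, and the callback replaces that talk's entry in an index keyed by its id, so updating one Talk touches one entry instead of the whole index. `TalkRepository`, `onPostSave` and `TalkIndex` are hypothetical names, and the index stores plain text where a real one would store embeddings. The callback runs after the id has been assigned; in JPA terms that is closer to `@PostPersist`/`@PostUpdate`, since the JPA specification only guarantees that generated primary keys are available in `PostPersist`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

// Plain-Java stand-in for the Panache Talk entity (no JPA here).
class Talk {
    Long id;
    String title;
    String description;

    Talk(String title, String description) {
        this.title = title;
        this.description = description;
    }
}

// Hypothetical repository: assigns an id, stores the row, then fires post-save listeners,
// playing the role of JPA @PostPersist/@PostUpdate callbacks.
class TalkRepository {
    private final Map<Long, Talk> rows = new HashMap<>();
    private final List<Consumer<Talk>> postSaveListeners = new ArrayList<>();
    private long nextId = 1;

    void onPostSave(Consumer<Talk> listener) {
        postSaveListeners.add(listener);
    }

    void save(Talk talk) {
        if (talk.id == null) {
            talk.id = nextId++; // the id exists before any listener runs
        }
        rows.put(talk.id, talk);
        for (Consumer<Talk> listener : postSaveListeners) {
            listener.accept(talk);
        }
    }
}

// Hypothetical index keyed by talk id; a real one would hold embeddings, not raw text.
class TalkIndex {
    final Map<Long, String> indexedText = new HashMap<>();

    // Replaces (or removes) only this talk's entry; other talks are untouched.
    void reindex(Talk talk) {
        if (talk.description == null || talk.description.isBlank()) {
            indexedText.remove(talk.id);
        } else {
            indexedText.put(talk.id,
                    "Title: " + talk.title + "\nID: " + talk.id + "\nDescription: " + talk.description);
        }
    }
}
```

Usage: `repo.onPostSave(index::reindex)` registers the index once; every later `repo.save(talk)` keeps that talk's entry current.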

I suppose a PanacheEmbeddingStore could offer an API like this for batch re-indexing, given an @IndexProducer method like the one above:

// DO LLM
Log.infof("Loading data from talks for LLM");
List<Document> documents = store.getDocumentsForModel(Talk.class);
store.removeAll();
Log.infof("Ingesting LLM");
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .embeddingStore(store)
                .embeddingModel(embeddingModel)
                .documentSplitter(DocumentSplitters.recursive(2000, 0))
                .build();
// Warning - this can take a long time...
ingestor.ingest(documents);
Log.infof("Ingesting LLM done");
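The hypothetical `getDocumentsForModel(Talk.class)` could be built on plain Java reflection: find the method annotated with `@IndexProducer`, call it on every entity, and drop the `null` results. This sketch assumes a runtime-retained `@IndexProducer` annotation (it does not exist in Quarkus or langchain4j), uses a simplified `Talk` class, and returns plain strings instead of langchain4j `Document`s. Because the method needs the entities themselves, it takes them as a list rather than only the class.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical marker annotation from the proposal (not part of Quarkus or langchain4j).
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface IndexProducer {}

// Simplified Talk: returns plain text instead of a langchain4j Document.
class Talk {
    final long id;
    final String title;
    final String description;

    Talk(long id, String title, String description) {
        this.id = id;
        this.title = title;
        this.description = description;
    }

    @IndexProducer
    public String documentForIndex() {
        if (description == null || description.isBlank()) {
            return null; // nothing worth indexing
        }
        return "Title: " + title + "\nID: " + id + "\nDescription: " + description;
    }
}

class DocumentCollector {
    // Looks up the @IndexProducer method once, calls it on every entity, and drops null results.
    static <T> List<String> documentsFor(Class<T> type, List<T> entities) {
        Method producer = Arrays.stream(type.getMethods())
                .filter(m -> m.isAnnotationPresent(IndexProducer.class))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException(type.getName() + " has no @IndexProducer method"));
        producer.setAccessible(true); // Talk is package-private in this sketch
        List<String> documents = new ArrayList<>();
        for (T entity : entities) {
            try {
                Object document = producer.invoke(entity);
                if (document != null) {
                    documents.add((String) document);
                }
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Could not call " + producer.getName(), e);
            }
        }
        return documents;
    }
}
```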

geoand commented 2 weeks ago

cc @jmartisk

jmartisk commented 2 weeks ago

Quickly updating the corresponding embeddings when one of the source documents changes is unfortunately not trivial, because there's no 'equality' relation between a document and the embeddings derived from it. So, to keep track of which embeddings belong to which document, you probably need some sort of ID, which can be stored as part of the metadata.

Could you perhaps, in the @PrePersist method, take the ID of the document being updated and do something like:

// remove the original embeddings (the "id" metadata was stored as a String)
store.removeAll(new IsEqualTo("id", talk.id.toString()));

// create the new embeddings
Map<String, String> metadata = new HashMap<>();
metadata.put("title", talk.title);
metadata.put("id", talk.id.toString());
Document document = new Document("Title: "+talk.title+"\nID: "+talk.id+"\nDescription: "+talk.description, Metadata.from(metadata));
List<TextSegment> split = DocumentSplitters.recursive(2000, 0).split(document);
List<Embedding> newEmbeddings = embeddingModel.embedAll(split).content();
// store each embedding together with its segment, so the text and its "id" metadata are kept
store.addAll(newEmbeddings, split);
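The remove-then-re-add idea can be checked with a toy in-memory store. `ToyStore`, `Entry` and `TalkReindexer` are hypothetical stand-ins, not the langchain4j `EmbeddingStore` API; the embedding call is omitted and the splitter is a naive fixed-size cut rather than `DocumentSplitters.recursive`. Re-indexing a talk deletes only the segments whose `id` metadata matches it, so the other talks are left untouched.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Toy stand-in for a stored segment (a real store would also hold its embedding vector).
record Entry(String text, Map<String, String> metadata) {}

// Toy in-memory store, NOT the langchain4j EmbeddingStore API.
class ToyStore {
    final List<Entry> entries = new ArrayList<>();

    void addAll(List<Entry> newEntries) {
        entries.addAll(newEntries);
    }

    // Analogue of removeAll(Filter): drops every entry whose metadata matches.
    void removeAll(Predicate<Map<String, String>> filter) {
        entries.removeIf(e -> filter.test(e.metadata()));
    }
}

class TalkReindexer {
    // Re-indexes a single talk: delete its old segments by "id" metadata, then add new ones.
    static void reindex(ToyStore store, long talkId, String title, String description, int maxSegmentSize) {
        String id = Long.toString(talkId); // stored as a String, like talk.id.toString() above
        store.removeAll(metadata -> id.equals(metadata.get("id")));

        String text = "Title: " + title + "\nID: " + id + "\nDescription: " + description;
        List<Entry> segments = new ArrayList<>();
        // Naive fixed-size split with no overlap; the embedding call is omitted.
        for (int start = 0; start < text.length(); start += maxSegmentSize) {
            String chunk = text.substring(start, Math.min(text.length(), start + maxSegmentSize));
            segments.add(new Entry(chunk, Map.of("id", id, "title", title)));
        }
        store.addAll(segments);
    }
}
```

Updating one talk therefore costs one delete plus the embedding of that talk's segments, not a rebuild of the whole store.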

This also depends on the particular embedding store implementing the removeAll method, though, and I think PgVector doesn't support it right now :/
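If a store might not implement selective removal, one option is to try the targeted update and fall back to a full rebuild. This is a toy: `SegmentStore`, `AppendOnlyStore` and `Sync` are hypothetical, and the default `removeByTalkId` simply throws `UnsupportedOperationException` to model a store without that capability. The caller passes the current text of every talk (already including the update) so the fallback can rebuild from scratch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Toy store interface: selective removal is optional, as with stores that don't implement removeAll(Filter).
interface SegmentStore {
    void add(String talkId, String text);

    void clear();

    default void removeByTalkId(String talkId) {
        throw new UnsupportedOperationException("selective removal not supported");
    }
}

// Hypothetical store that only supports appending and clearing everything.
class AppendOnlyStore implements SegmentStore {
    final List<String[]> rows = new ArrayList<>();

    @Override
    public void add(String talkId, String text) {
        rows.add(new String[] {talkId, text});
    }

    @Override
    public void clear() {
        rows.clear();
    }
}

class Sync {
    // Tries the cheap targeted update first; falls back to a full rebuild from allTalks,
    // which must already contain the updated text.
    static void update(SegmentStore store, String talkId, String newText, Map<String, String> allTalks) {
        try {
            store.removeByTalkId(talkId);
            store.add(talkId, newText);
        } catch (UnsupportedOperationException e) {
            store.clear();
            allTalks.forEach(store::add);
        }
    }
}
```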

jmartisk commented 2 weeks ago

I've just noticed the related issue https://github.com/langchain4j/langchain4j/issues/1299 (except it's about using the embedding ID instead of an ID stored in the metadata).
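The embedding-ID variant can be sketched with a side table that maps each talk to the ids the store returned when that talk's segments were added; an update then deletes exactly those ids. `IdStore` and `IdTrackingIndex` are hypothetical in-memory stand-ins. In langchain4j, `EmbeddingStore.addAll(...)` returns the generated ids and the interface declares `removeAll(Collection<String> ids)`, but whether a particular store implements it would need checking.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

// Toy store that assigns a random id to every stored segment and returns those ids from addAll.
class IdStore {
    final Map<String, String> segmentsById = new LinkedHashMap<>();

    List<String> addAll(List<String> segments) {
        List<String> ids = new ArrayList<>();
        for (String segment : segments) {
            String id = UUID.randomUUID().toString();
            segmentsById.put(id, segment);
            ids.add(id);
        }
        return ids;
    }

    void removeAll(Collection<String> ids) {
        ids.forEach(segmentsById::remove);
    }
}

// Hypothetical side table: talk id -> ids of the segments stored for that talk.
class IdTrackingIndex {
    final IdStore store = new IdStore();
    final Map<Long, List<String>> idsByTalk = new HashMap<>();

    void index(long talkId, List<String> segments) {
        List<String> previous = idsByTalk.remove(talkId);
        if (previous != null) {
            store.removeAll(previous); // drop exactly this talk's previous segments
        }
        idsByTalk.put(talkId, store.addAll(segments));
    }
}
```

The trade-off is that the side table itself has to be persisted and kept consistent, for example as a column on the `Talk` row.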

langchain4j commented 2 weeks ago

PgVectorEmbeddingStore supports removeAll(Filter), where the Filter can be metadataKey("id").isEqualTo(talk.id).
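A toy version of that fluent filter shows one detail worth checking: the code earlier in the thread stores the id as `talk.id.toString()`, and in this toy, which compares with strict `equals`, a `Long` never matches the stored `String`. `ToyFilters` only imitates the shape of `metadataKey(...).isEqualTo(...)`; whether the real pgvector filter converts between the two types is not settled in this thread, so passing `talk.id.toString()` looks like the safer choice.

```java
import java.util.Map;
import java.util.function.Predicate;

// Toy imitation of a fluent metadata filter; the names echo langchain4j's
// metadataKey(...).isEqualTo(...), but this is NOT that API.
final class ToyFilters {
    private ToyFilters() {}

    static Key metadataKey(String name) {
        return new Key(name);
    }

    record Key(String name) {
        // Strict equals(): a Long 7 does not match a stored String "7".
        Predicate<Map<String, Object>> isEqualTo(Object value) {
            return metadata -> value.equals(metadata.get(name));
        }
    }
}
```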