langgenius / dify

Dify is an open-source LLM app development platform. Dify's intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
https://dify.ai

Performance issues in knowledge base construction. #7491

Open codingMrHu opened 3 weeks ago

codingMrHu commented 3 weeks ago

Self Checks

Dify version

main

Cloud or Self Hosted

Self Hosted (Source)

Steps to reproduce

When the Celery task processes and saves knowledge base files, it stalls and takes an excessively long time once the paragraph data reaches the scale of hundreds of thousands of segments.

IndexingRunner().run() contains the code below. self._load_segments inserts rows into the PostgreSQL document_segments table, but the next call, self._load, invokes _update_segment_keywords() to update each document_segments row one by one, which takes a very long time. I think it would be better to extract the keywords when the document_segments rows are inserted and save them at that point.


# transform
documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                            processing_rule.to_dict())

# save segment
self._load_segments(dataset, dataset_document, documents)

# load
self._load(
    index_processor=index_processor,
    dataset=dataset,
    dataset_document=dataset_document,
    documents=documents
)

api/core/rag/datasource/keyword/jieba/jieba.py:_update_segment_keywords is called once per segment from create():

    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
        with redis_client.lock(lock_name, timeout=600):
            keyword_table_handler = JiebaKeywordTableHandler()
            keyword_table = self._get_dataset_keyword_table()
            for text in texts:
                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
                # updates one document_segments row per segment
                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

            self._save_dataset_keyword_table(keyword_table)

            return self
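
For illustration, the per-segment round trips could be collapsed into one query and one commit. The following is only a sketch, assuming DocumentSegment and the shared SQLAlchemy db session are importable in jieba.py; it is not an actual implementation:

# Sketch only: batched variant of create(); assumes DocumentSegment and the
# shared `db` session are available in this module.
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
    lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
    with redis_client.lock(lock_name, timeout=600):
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = self._get_dataset_keyword_table()

        # load all segments of this dataset once, keyed by index_node_id
        segments = {
            segment.index_node_id: segment
            for segment in db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset.id
            )
        }

        for text in texts:
            keywords = keyword_table_handler.extract_keywords(
                text.page_content, self._config.max_keywords_per_chunk
            )
            segment = segments.get(text.metadata['doc_id'])
            if segment:
                segment.keywords = list(keywords)
            keyword_table = self._add_text_to_keyword_table(
                keyword_table, text.metadata['doc_id'], list(keywords)
            )

        # one commit for all keyword updates instead of a query and update per segment
        db.session.commit()
        self._save_dataset_keyword_table(keyword_table)

        return self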

✔️ Expected Behavior

No response

❌ Actual Behavior

No response

dosubot[bot] commented 3 weeks ago

Hello @codingMrHu! I'm here to help you with bugs, questions, and becoming a contributor. Let's tackle this issue together while we wait for a human maintainer.

To address the performance issues in the IndexingRunner().run() method, particularly with the _update_segment_keywords() function, you can optimize the process by extracting and saving keywords during the segment insertion phase. Here are some strategies to improve performance:

  1. Extract Keywords During Segment Insertion: Modify the _load_segments method to include keyword extraction and update the segments with keywords during insertion.

  2. Batch Processing: Process documents in batches to reduce the overhead of database operations (see the batching sketch after the code examples below).

  3. Parallel Processing: Utilize multi-threading or multi-processing to handle multiple segments concurrently.

Here is an example of how you can implement these optimizations:

Modified IndexingRunner Class

import concurrent.futures
import datetime
import logging

from sqlalchemy.orm.exc import ObjectDeletedError

# placeholder imports -- adjust these paths to the real modules in your project
from your_project.extensions import db, storage
from your_project.models import (Dataset, DatasetDocument, DatasetProcessRule,
                                 DocumentSegment, DatasetDocumentStore)
from your_project.model_manager import ModelManager
from your_project.index_processor import IndexProcessorFactory
from your_project.errors import DocumentIsPausedException, ProviderTokenNotInitError
from your_project.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(self._process_document, dataset_document) for dataset_document in dataset_documents]
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logging.exception("Error processing document: %s", e)

    def _process_document(self, dataset_document):
        try:
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")

            processing_rule = db.session.query(DatasetProcessRule).filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id).first()
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict())

            self._load_segments(dataset, dataset_document, documents)
            self._load(index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents)

        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except ObjectDeletedError:
            logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def _load_segments(self, dataset, dataset_document, documents):
        doc_store = DatasetDocumentStore(dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id)
        keyword_handler = JiebaKeywordTableHandler()

        for document in documents:
            keywords = keyword_handler.extract_keywords(document.page_content)
            document.keywords = list(keywords)

        doc_store.add_documents(documents)

        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(document_id=dataset_document.id, after_indexing_status="indexing", extra_update_params={
            DatasetDocument.cleaning_completed_at: cur_time,
            DatasetDocument.splitting_completed_at: cur_time,
        })

        self._update_segments_by_document(dataset_document_id=dataset_document.id, update_params={
            DocumentSegment.status: "indexing",
            DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        })

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params, synchronize_session=False)
        db.session.commit()

Optimized JiebaKeywordTableHandler Class

import re
from typing import Optional
import jieba
from jieba.analyse import default_tfidf
from your_project.keyword.jieba.stopwords import STOPWORDS
from pydantic import BaseModel

class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10

class JiebaKeywordTableHandler:

    def __init__(self):
        default_tfidf.stop_words = STOPWORDS
        self.stopwords_set = set(STOPWORDS)
        self.word_pattern = re.compile(r"\w+")
        self.config = KeywordTableConfig()

    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = None) -> set[str]:
        if max_keywords_per_chunk is None:
            max_keywords_per_chunk = self.config.max_keywords_per_chunk

        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = self.word_pattern.findall(token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in self.stopwords_set})

        return results
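
Strategy 2 (batch processing) is not shown in the classes above. Here is a minimal sketch of chunked segment insertion, assuming the same DatasetDocumentStore and JiebaKeywordTableHandler as in the IndexingRunner example; batch_size is a hypothetical tuning parameter:

# Sketch only: insert segments in fixed-size chunks so each transaction and
# session flush stays small.
def _load_segments_in_batches(self, dataset, dataset_document, documents, batch_size: int = 500):
    doc_store = DatasetDocumentStore(
        dataset=dataset,
        user_id=dataset_document.created_by,
        document_id=dataset_document.id,
    )
    keyword_handler = JiebaKeywordTableHandler()

    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        # extract keywords up front so they are stored with the segment rows
        for document in batch:
            document.keywords = list(keyword_handler.extract_keywords(document.page_content))
        doc_store.add_documents(batch)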

By implementing these optimizations, you should see improved performance in the knowledge base construction process, especially when dealing with large-scale paragraph data [1][2].

To continue talking to Dosu, mention @dosu.

crazywoola commented 3 weeks ago

Feel free to submit a PR for this :)

codingMrHu commented 3 weeks ago

Regarding the suggestion above: keywords are stored in document_segments, not on the documents, so extract_keywords should be added to DatasetDocumentStore.add_documents instead:

    def add_documents(
            self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0
        embedding_model = None
        if self._dataset.indexing_technique == 'high_quality':
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=self._dataset.tenant_id,
                provider=self._dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=self._dataset.embedding_model
            )
        start = time.perf_counter()
        logging.info(f"start add segment_document:{self._document_id}, len docs:{len(docs)}")
        # proposed change: build the keyword processor once here so keywords can be
        # extracted while each DocumentSegment row is created below
        keyword_processor = Keyword(self._dataset)._keyword_processor
        for doc in docs:
            if not isinstance(doc, Document):
                raise ValueError("doc must be a Document")

            segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.metadata['doc_id']} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calc embedding use tokens
            if embedding_model:
                tokens = embedding_model.get_text_embedding_num_tokens(
                    texts=[doc.page_content]
                )
            else:
                tokens = 0

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.metadata['doc_id'],
                    index_node_hash=doc.metadata['doc_hash'],
                    position=max_position,
                    content=doc.page_content,
                    keywords=keyword_processor.extract_keywords(doc.page_content),
                    word_count=len(doc.page_content),
                    tokens=tokens,
                    enabled=False,
                    created_by=self._user_id,
                )
                if doc.metadata.get('answer'):
                    segment_document.answer = doc.metadata.pop('answer', '')

                db.session.add(segment_document)
            else:
                segment_document.content = doc.page_content
                segment_document.keywords = keyword_processor.extract_keywords(doc.page_content)
                if doc.metadata.get('answer'):
                    segment_document.answer = doc.metadata.pop('answer', '')
                segment_document.index_node_hash = doc.metadata['doc_hash']
                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens