Greptile: Add docstrings to the backend/services/genai.py file #5

DanielDaCosta opened 3 months ago

greptile-apps[bot] commented 3 months ago

Disclaimer: This comment was generated by a bot.

Add docstrings to all classes and methods in the backend/services/genai.py file.

class GeminiProcessor:
    """Wrap a Vertex AI model for document summarization and character counting.

    Attributes:
        model (VertexAI): LangChain wrapper around the configured Vertex AI model.
    """

    def __init__(self, model_name, project) -> None:
        """Create the Vertex AI model wrapper.

        Args:
            model_name (str): Name of the Vertex AI model to use.
            project (str): Google Cloud project the model is called under.
        """
        self.model = VertexAI(model_name=model_name, project=project)

    def generate_document_summary(self, documents: list, **args):
        """Summarize a list of documents with a LangChain summarize chain.

        More than 10 documents use the "map_reduce" chain type; 10 or fewer
        use the "stuff" chain type.

        Args:
            documents (list): Documents to summarize.
            **args: Extra keyword arguments forwarded to ``load_summarize_chain``.

        Returns:
            str: The generated summary.
        """
        chain_type = "map_reduce" if len(documents) > 10 else "stuff"

        chain = load_summarize_chain(
            llm=self.model,
            chain_type=chain_type,
            **args,
        )

        return chain.run(documents)

    def count_total_tokens(self, docs: list):
        """Sum the billable characters of a list of documents.

        Despite the method name, the value summed is ``total_billable_characters``
        from ``GenerativeModel.count_tokens``, called once per document.

        Args:
            docs (list): Documents exposing a ``page_content`` string.

        Returns:
            int: Total billable characters over all documents.
        """
        # NOTE(review): counts against a fixed "gemini-1.0-pro" model rather than
        # the model this processor was built with -- confirm that is intended.
        temp_model = GenerativeModel("gemini-1.0-pro")

        total = 0
        logger.info("Counting total billable characters...")
        for doc in tqdm(docs):
            total += temp_model.count_tokens(doc.page_content).total_billable_characters

        return total

    def get_model(self):
        """Return the model wrapped by this processor.

        Returns:
            VertexAI: The model created in ``__init__``.
        """
        return self.model

class YoutubeProcessor:
    """Load YouTube transcripts and extract key concepts from them.

    Attributes:
        text_splitter (RecursiveCharacterTextSplitter): Splits loaded transcripts
            into smaller documents.
        GeminiProcessor (GeminiProcessor): Supplies the model used for concept
            extraction and the billable-character counter.
    """

    # Hard-coded rate applied to both input and output characters in verbose
    # cost logging. NOTE(review): presumably USD per 1,000 characters -- confirm
    # against current Vertex AI pricing.
    _COST_PER_1K_CHARS = 0.000125

    def __init__(self, genai_processor: "GeminiProcessor") -> None:
        """Create the processor.

        Args:
            genai_processor (GeminiProcessor): Processor whose model is used for
                concept extraction and character counting.
        """
        # NOTE(review): the splitter's arguments were cut from this snippet; the
        # values below are a reconstruction -- confirm against the real source.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        self.GeminiProcessor = genai_processor

    def retrieve_youtube_documents(self, video_url: str, verbose = False):
        """Load a YouTube video's transcript and split it into documents.

        Args:
            video_url (str): URL of the YouTube video.
            verbose (bool): If True, log the video's author, length and title,
                the number of split documents, and their total billable characters.

        Returns:
            list: The transcript split by ``self.text_splitter`` (empty if the
            loader returned nothing).
        """
        # add_video_info=True makes the loader attach the author/length/title
        # metadata that the verbose log below reads.
        loader = YoutubeLoader.from_youtube_url(video_url, add_video_info=True)
        docs = loader.load()
        result = self.text_splitter.split_documents(docs)

        if verbose:
            if not result:
                # result[0] below would raise IndexError on an empty transcript.
                logger.warning(f"No transcript documents were loaded for {video_url}")
                return result

            metadata = result[0].metadata
            author = metadata['author']
            length = metadata['length']
            title = metadata['title']
            total_size = len(result)
            # Counting calls count_tokens once per document, so only do it when
            # the figure is actually going to be logged.
            total_billable_characters = self.GeminiProcessor.count_total_tokens(result)
            logger.info(f"{author} \n {length}\n{title}\n{total_size}\n{total_billable_characters}")

        return result

    def format_processed_concepts(self, processed_concepts):
        """Merge concept dictionaries into a list of term/definition records.

        Args:
            processed_concepts (list): Dictionaries mapping a term to its
                definition. When a term appears more than once, the definition
                from the later dictionary wins.

        Returns:
            list: ``[{"term": term, "definition": definition}, ...]`` in
            first-seen term order.
        """
        combined_dict = {}
        for concepts in processed_concepts:
            combined_dict.update(concepts)

        # Convert combined dictionary into the required format
        return [{"term": key, "definition": value} for key, value in combined_dict.items()]

    def find_key_concepts(self, documents: list, sample_size: int=0, verbose=False):
        """Extract key concepts and their definitions from a list of documents.

        The documents are split into ``sample_size`` groups of consecutive
        documents; each group's text is sent to the model in one prompt, and the
        JSON objects it returns are merged by ``format_processed_concepts``.

        Args:
            documents (list): Documents exposing a ``page_content`` string.
            sample_size (int): Number of groups to split ``documents`` into.
                ``0`` (the default) uses ``len(documents) // 5`` groups (at least
                one), i.e. roughly five documents per group.
            verbose (bool): Whether to log per-group character counts and an
                estimated cost.

        Returns:
            list: ``[{"term": ..., "definition": ...}, ...]``; empty when
            ``documents`` is empty or no group produced a parseable JSON object.

        Raises:
            ValueError: If ``sample_size`` is negative or larger than
                ``len(documents)``, or if each group would hold more than 10
                documents.
        """
        if sample_size < 0:
            raise ValueError("sample_size must be non-negative")

        if sample_size > len(documents):
            raise ValueError("Group size is larger than the number of documents")

        if not documents:
            return []

        if sample_size == 0:
            # max() keeps at least one group so fewer than 5 documents cannot
            # lead to a division by zero below.
            sample_size = max(1, len(documents) // 5)
            if verbose:
                logger.info(f"No sample size specified. Setting number of documents per sample as 5. Sample size: {sample_size}")

        # Ceiling division: documents per group when split into sample_size groups.
        num_docs_per_group = len(documents) // sample_size + (len(documents) % sample_size > 0)

        if num_docs_per_group > 10:
            raise ValueError("Each group has more than 10 documents and output quality will be degraded "
            "significantly. Increase the sample_size parameter to reduce the number of documents per group.")
        elif num_docs_per_group > 5:
            logger.warning("Each group has more than 5 documents and output quality is likely to be degraded. "
                           "Consider increasing the sample size")

        groups = [documents[i:i + num_docs_per_group] for i in range(0, len(documents), num_docs_per_group)]

        # The prompt does not depend on the group, so build the chain once.
        # {text} receives the group's content; the literal JSON braces are
        # doubled because PromptTemplate uses f-string template syntax.
        prompt = PromptTemplate(template="""
                     Find the key concepts and their definitions from the following text:
                     {text}
                     Respond only in clean JSON format without any labels or additional text. The output exactly should look like this:
                     {{"concept1": "definition1", "concept2": "definition2"}}
                     """, input_variables=["text"])
        chain = prompt | self.GeminiProcessor.model

        processed_concepts = []
        logger.info("Finding key concepts...")
        for group in tqdm(groups):
            group_content = "".join(doc.page_content for doc in group)

            if not group_content:
                logger.warning("No content to process for this group.")
                continue

            # Bound before the try so the verbose block can always read it.
            output_concept = ""
            try:
                output_concept = chain.invoke({"text": group_content})
                # Strip Markdown code fences the model may wrap around the JSON.
                output_concept = output_concept.replace("```json", "").replace("```", "").strip()
                parsed = json.loads(output_concept)
                if isinstance(parsed, dict):
                    processed_concepts.append(parsed)
                else:
                    logger.error("Model output for group was not a JSON object; skipping it.")
            except Exception as e:
                # Best effort: a failed or unparseable group is logged and skipped.
                logger.error(f"Failed to find concepts for group: {e}")

            if verbose:
                total_input_char = len(group_content)
                total_input_cost = (total_input_char / 1000) * self._COST_PER_1K_CHARS
                logger.info(f"Running chain on {len(group)} documents")
                logger.info(f"Total input characters: {total_input_char} ")
                logger.info(f"Total cost: {total_input_cost} ")
                total_output_char = len(output_concept)
                total_output_cost = (total_output_char / 1000) * self._COST_PER_1K_CHARS
                logger.info(f"Total output characters: {total_output_char} ")
                logger.info(f"Total output cost: {total_output_cost} ")

        return self.format_processed_concepts(processed_concepts)



