astronomer / ask-astro

An end-to-end LLM reference implementation providing a Q&A interface for Airflow and Astronomer
https://ask.astronomer.io/
Apache License 2.0
192 stars 47 forks

Switch to context manager for batch import #173

Closed: mpgreg closed this issue 10 months ago

mpgreg commented 10 months ago

https://github.com/astronomer/ask-astro/blob/515f3386c4eac8aa4ddcdc3ad12c46b52e4aad8a/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py#L229-L263

Weaviate recommends ingesting data with a context manager (https://weaviate.io/developers/weaviate/manage-data/import). The current code emits a warning recommending the context manager.
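To show why the context-manager form is preferred, here is a minimal, self-contained sketch of the pattern. `ToyBatch` is a hypothetical stand-in, not the Weaviate client. It buffers objects, sends a full batch whenever `batch_size` is reached, and flushes whatever is left when the `with` block exits, including when the loop ends early with `break` or an exception.

```python
from __future__ import annotations


class ToyBatch:
    """Hypothetical batcher illustrating the context-manager pattern (not Weaviate's API)."""

    def __init__(self, batch_size: int = 2) -> None:
        self.batch_size = batch_size
        self._buffer: list[dict] = []
        self.sent: list[list[dict]] = []  # every batch "sent to the server", in order

    def add_data_object(self, data_object: dict) -> None:
        self._buffer.append(data_object)
        if len(self._buffer) >= self.batch_size:
            self.flush()

    def flush(self) -> None:
        # Send the buffered objects (here: just record them) and start a new buffer.
        if self._buffer:
            self.sent.append(self._buffer)
            self._buffer = []

    def __enter__(self) -> "ToyBatch":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Always flush the partial final batch, then let any exception propagate.
        self.flush()
        return False


batch = ToyBatch(batch_size=2)
with batch:
    for i in range(5):
        batch.add_data_object({"id": i})

print([len(b) for b in batch.sent])  # [2, 2, 1]
```

Without the `with` block, the last object would stay in the buffer until `flush()` was called explicitly; handing that final flush to `__exit__` is the point of the context-manager form.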

mpgreg commented 10 months ago

Suggested fix:

```python
from __future__ import annotations

import pandas as pd
from airflow.exceptions import AirflowException


# Method of the hook in ask_astro_weaviate_hook.py; self.client is the hook's Weaviate client.
def batch_ingest(
    self,
    df: pd.DataFrame,
    class_name: str,
    uuid_column: str,
    existing: str,
    vector_column: str | None = None,
    tenant: str | None = None,
    batch_params: dict | None = None,
    verbose: bool = False,
) -> list:
    """
    Processes the DataFrame and batches the data for ingestion into Weaviate. Ingest includes
    a callback function to handle (upsert, rollback, etc.) errors in batch ingest. If a user sets
    'callback' in batch_params, the user is expected to handle these errors.

    :param df: DataFrame containing the data to be ingested.
    :param class_name: The name of the class in Weaviate to which data will be ingested.
    :param uuid_column: Name of the column containing the UUID.
    :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert' or 'error').
    :param vector_column: Name of the column containing the vector data.
    :param tenant: The tenant to which the object will be added.
    :param batch_params: Parameters for batch configuration.
    :param verbose: Whether to log verbose output.
    :return: List of any objects that failed to be added to the batch.
    """
    # Copy so that neither the caller's dict nor a shared default is mutated.
    batch_params = dict(batch_params or {})
    if not batch_params.get("callback"):
        batch_params["callback"] = self.process_batch_errors

    self.client.batch.configure(**batch_params)
    self.batch_errors = []

    # The context manager flushes any remaining objects when the block exits.
    with self.client.batch as batch:
        for row_id, row in df.iterrows():
            data_object = row.to_dict()
            uuid = data_object.pop(uuid_column)
            vector = data_object.pop(vector_column, None)

            try:
                if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant):
                    if existing == "error":
                        raise AirflowException(f"Ingest of UUID {uuid} failed. Object exists.")
                    elif existing == "skip":
                        if verbose:
                            self.logger.info(f"UUID {uuid} exists. Skipping.")
                        continue
                    elif existing == "replace":
                        # Weaviate's default behaviour is to replace the existing object.
                        if verbose:
                            self.logger.info(f"UUID {uuid} exists. Overwriting.")
                    # 'upsert' (or any other value) falls through to add_data_object below.
            except AirflowException as e:
                # existing == "error": record the failure and stop queuing further rows.
                self.logger.error(e)
                self.batch_errors.append({"uuid": uuid, "errors": [str(e)]})
                break
            except Exception as e:
                self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
                self.batch_errors.append({"uuid": uuid, "errors": [str(e)]})
                continue

            try:
                added_row = batch.add_data_object(
                    class_name=class_name, uuid=uuid, data_object=data_object, vector=vector, tenant=tenant
                )
                if verbose:
                    self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")
            except Exception as e:
                if verbose:
                    self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
                self.batch_errors.append({"uuid": uuid, "errors": [str(e)]})

    return self.batch_errors
```
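One detail worth checking in any version of this signature: a default such as `batch_params: dict = {}` is created once, when the `def` statement runs, so calling `batch_params.update(...)` on it changes the default seen by every later call. The sketch below reproduces this with two hypothetical functions (`configure_shared` and `configure_copied`, plus a stand-in `default_callback`; none of these are part of ask-astro) and shows the usual `None`-default-plus-copy fix.

```python
from __future__ import annotations


def default_callback(results):
    """Hypothetical stand-in for the hook's process_batch_errors."""
    return results


def configure_shared(batch_params: dict = {}) -> dict:
    # Anti-pattern: mutates the single default dict shared by all calls.
    if not batch_params.get("callback"):
        batch_params.update({"callback": default_callback})
    return batch_params


def configure_copied(batch_params: dict | None = None) -> dict:
    # Fix: build a fresh dict per call, so neither the default nor the caller's dict is mutated.
    batch_params = dict(batch_params or {})
    if not batch_params.get("callback"):
        batch_params["callback"] = default_callback
    return batch_params


first = configure_shared()
second = configure_shared()
print(first is second)  # True: both calls returned the same default dict

user_params = {"batch_size": 50}
configured = configure_copied(user_params)
print("callback" in user_params, "callback" in configured)  # False True
```

In the hook this would matter because the stored callback is a bound method: with a shared `{}` default, a later call on a different hook instance would keep the first instance's `process_batch_errors`, and a caller-supplied dict without a `callback` key would be modified in place.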
sunank200 commented 10 months ago

@mpgreg did we not implement this before? You had asked to remove the context manager for #132.