astronomer / astro-sdk

Astro SDK allows rapid and clean development of {Extract, Load, Transform} workflows using Python and SQL, powered by Apache Airflow.
https://astro-sdk-python.rtfd.io/
Apache License 2.0

Table naming when expanding a LoadFileOperator #1414

Open ReadytoRocc opened 1 year ago

ReadytoRocc commented 1 year ago

Please describe the feature you'd like to see
The LoadFileOperator accepts output_table and input_file arguments. When expanding over a set of files, you may want to load them all with the exact same table definition even though they come from different paths (e.g. 5 files for the same data feed). Expanding over input_file (a list) works, but auto-generated _tmp table names are assigned to the loaded tables (see the sketch below). To name the tables ourselves, we currently have to generate a config that combines the desired table name, the Table object, and the File object, and expand over that. This requires a separate task, or methods such as map and zip, to prepare the load configurations.
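
For context, here is a minimal sketch of what expanding directly over input_file looks like today (inside a DAG body; the bucket paths, conn_id, and dataset are placeholders): every mapped load ends up in a table with an auto-generated _tmp name because no explicit table name is given.

from astro.files import File
from astro.sql.operators.load_file import LoadFileOperator as LoadFile
from astro.sql.table import Metadata, Table

# Expanding only over input_file: the unnamed output_table is given an
# auto-generated _tmp name for each mapped load.
load_feed = LoadFile.partial(
    task_id="load_feed",
    output_table=Table(
        conn_id="google_cloud_default",  # placeholder connection
        metadata=Metadata(schema="my_dataset"),  # placeholder dataset
    ),
).expand(
    input_file=[
        File(path="gs://my-bucket/feed/part-1.csv", conn_id="google_cloud_default"),
        File(path="gs://my-bucket/feed/part-2.csv", conn_id="google_cloud_default"),
    ]
)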

Describe the solution you'd like
I would like the ability to directly pass a table_name_parsing_function. It would take the file name from input_file as an argument and allow the user to parse it into a table name for that file. By default, we could transform the file name into an ANSI-SQL-compliant table name, for example by combining the file name, run id, and task index. A rough sketch of the proposed argument is shown below.
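
Roughly, usage might look like the following (table_name_parsing_function does not exist in astro-sdk today, the default parser shown is only illustrative, and the conn_id/schema values are placeholders):

import os

from astro.sql.operators.load_file import LoadFileOperator as LoadFile
from astro.sql.table import Metadata, Table

def table_name_from_file(file_path: str) -> str:
    # Illustrative default: normalise the file name into an
    # ANSI-SQL-compliant identifier, e.g. "feed-1.csv" -> "feed_1".
    return os.path.basename(file_path).split(".")[0].replace("-", "_").lower()

load_gcs_to_bq = LoadFile.partial(
    task_id="load_gcs_to_bq",
    output_table=Table(
        conn_id="google_cloud_default",  # placeholder connection
        metadata=Metadata(schema="my_dataset"),  # placeholder dataset
    ),
    table_name_parsing_function=table_name_from_file,  # proposed argument, not an existing one
).expand(input_file=file_list)  # file_list: the mapped list of File objects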

Are there any alternatives to this feature?
Noted above in the description.

Acceptance Criteria

sunank200 commented 1 year ago

@ReadytoRocc could you share some example DAGs as per the description above?

ReadytoRocc commented 1 year ago

> @ReadytoRocc could you share some example DAGs as per the description above?

@sunank200 - please see the example below:

from airflow.decorators import dag, task
from airflow.exceptions import AirflowSkipException

from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.operators.python import get_current_context

from astro import sql as aql
from astro.sql.table import Table, Metadata
from astro.files import File
from astro.sql.operators.load_file import LoadFileOperator as LoadFile

from datetime import datetime
import pandas as pd

@task(task_id="parse_load_configs")
def parse_load_configs_func(output_table_dataset, output_table_conn_id, file_list):
    import os

    load_configs = []

    for file in file_list:
        table = Table(
            metadata=Metadata(
                schema=output_table_dataset,
            ),
            conn_id=output_table_conn_id,
            temp=False,
        )
        table.name = os.path.basename(file.path).split(".")[0]

        load_configs.append({"output_table": table, "input_file": file})

    return load_configs

@task(task_id="scan_gcs")
def gcs_scan_func(
    gcp_conn_id, bucket_name, prefix=None, delimiter=None, regex=None, **kwargs
):
    import re

    gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)

    timespan_start = kwargs["data_interval_start"]
    timespan_end = kwargs["data_interval_end"]

    print(f"Scaning between {timespan_start} and {timespan_end}.")
    files = gcs_hook.list_by_timespan(
        bucket_name=bucket_name,
        prefix=prefix,
        delimiter=delimiter,
        timespan_start=timespan_start,
        timespan_end=timespan_end,
    )

    # Wrap every object name in an astro File (optionally filtering with the
    # supplied regex) so downstream tasks can rely on file.path.
    _files = []
    re_com = re.compile(regex) if regex else None
    for file in files:
        if re_com and not re_com.fullmatch(file):
            continue
        _files.append(File(path=f"gs://{bucket_name}/{file}", conn_id=gcp_conn_id))
    files = _files

    if len(files) == 0:
        raise AirflowSkipException("No files found, skipping.")
    return files

# Variables
BIGQUERY_DATASET = ""
GCP_CONN_ID = ""
GCS_BUCKET = ""

@dag(
    schedule_interval="* * * * *",
    start_date=datetime(2022, 12, 5),
    catchup=False,
)
def bq_sdk():
    gcs_scan_task = gcs_scan_func(
        gcp_conn_id=GCP_CONN_ID, bucket_name=GCS_BUCKET, regex=r".*\.csv"
    )

    parse_load_configs_func_task = parse_load_configs_func(
        output_table_dataset=BIGQUERY_DATASET,
        output_table_conn_id=GCP_CONN_ID,
        file_list=gcs_scan_task,
    )

    # Map one load task instance per {"output_table", "input_file"} config.
    load_gcs_to_bq = LoadFile.partial(
        task_id="load_gcs_to_bq",
        use_native_support=True,
    ).expand_kwargs(parse_load_configs_func_task)

dag_obj = bq_sdk()