apache / airflow

Apache Airflow - A platform to programmatically author, schedule, and monitor workflows
https://airflow.apache.org/
Apache License 2.0

Add "job_clusters" in template_fields - DatabricksWorkflowTaskGroup #42438

Open sushil-louisa opened 3 days ago

sushil-louisa commented 3 days ago

Description

It would be great if we could allow CreateDatabricksWorkflowOperator to support Jinja templating for the job_clusters attribute. This would let us dynamically render the cluster configuration at runtime based on the context, such as Spark environment variables or Spark configurations.

https://github.com/apache/airflow/blob/e0bddbc438872277ca8de7fb794285cf546ddc0b/airflow/providers/databricks/operators/databricks_workflow.py#L96
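
A minimal sketch of the change being requested, assuming the launch operator behind the task group currently lists only "notebook_params" in its template_fields (the exact class name and tuple at the linked line may differ):

    # airflow/providers/databricks/operators/databricks_workflow.py (sketch, not an actual diff)
    class _CreateDatabricksWorkflowOperator(BaseOperator):
        # Adding "job_clusters" here would make Airflow render Jinja expressions
        # (including XComArg references) inside the cluster specs at runtime.
        template_fields = ("notebook_params", "job_clusters")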

Use case/motivation

I have an Airflow DAG in which I am trying to use the output of an upstream task in the job_clusters attribute while defining the DatabricksWorkflowTaskGroup. In the current implementation, the output of the upstream task does not get Jinja-templated at runtime.

Sample:


from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

dag = DAG(
    dag_id="example_databricks_workflow",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "databricks"],
)
with dag:

    fetch_config = PythonOperator(
        task_id="fetch_config",
        python_callable=fetch_config_func,
    )

    task_group = DatabricksWorkflowTaskGroup(
        group_id="test_workflow",
        job_clusters=[
            {
                "job_cluster_key": "Shared_job_cluster",
                "new_cluster": {
                    "cluster_name": "",
                    "spark_version": "11.3.x-scala2.12",
                    "aws_attributes": {
                        ...
                    },
                    "node_type_id": "i3.xlarge",
                    # Pass the output of the fetch_config task as a Spark env variable.
                    "spark_env_vars": {
                        "PYSPARK_PYTHON": "/databricks/python3/bin/python3",
                        "CONFIG": f"{fetch_config.output}",
                    },
                    "enable_elastic_disk": False,
                    "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
                    "runtime_engine": "STANDARD",
                    "num_workers": 8,
                },
            }
        ],
    )
    with task_group:

        task_operator_nb_1 = DatabricksTaskOperator(
            task_id="nb_1",
            databricks_conn_id="databricks_conn",
            job_cluster_key="Shared_job_cluster",
            task_config={
                "notebook_task": {
                    "notebook_path": "/Shared/Notebook_1",
                    "source": "WORKSPACE",
                },
                "libraries": [
                    {"pypi": {"package": "Faker"}},
                    {"pypi": {"package": "simplejson"}},
                ],
            },
        )

        sql_query = DatabricksTaskOperator(
            task_id="sql_query",
            databricks_conn_id="databricks_conn",
            task_config={
                "sql_task": {
                    "query": {
                        "query_id": QUERY_ID,
                    },
                    "warehouse_id": WAREHOUSE_ID,
                }
            },
        )

        task_operator_nb_1 >> sql_query

    fetch_config >> task_group
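
For context on why the value stays literal: f"{fetch_config.output}" stringifies the XComArg into a Jinja expression, and Airflow only renders such strings for attributes named in the operator's template_fields. Roughly (the exact pull expression is an approximation):

    str(fetch_config.output)
    # ~ "{{ task_instance.xcom_pull(task_ids='fetch_config', dag_id='example_databricks_workflow', key='return_value') }}"
    # Since "job_clusters" is not templated, this literal string (rather than the upstream
    # task's return value) is what ends up in the cluster spec sent to Databricks.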

Related issues

No response

Are you willing to submit a PR?

Code of Conduct

boring-cyborg[bot] commented 3 days ago

Thanks for opening your first issue here! Be sure to follow the issue template! If you are willing to raise PR to address this issue please do so, no need to wait for approval.