astronomer / astronomer-cosmos

Run your dbt Core projects as Apache Airflow DAGs and Task Groups with a few lines of code
https://astronomer.github.io/astronomer-cosmos/
Apache License 2.0

Introducing composability in the middle layer of Cosmos's API #895

Open dwreeves opened 6 months ago

dwreeves commented 6 months ago

TLDR

Sorry, this is long!!!! But there is a lot to cover. This is a massive proposal to reorganize massive chunks of Cosmos.

Here's a TLDR:

Introduction

At my last company, I built a dbt + Airflow integration before Cosmos existed. At my current company, we are using Cosmos.

There are a handful of features I really miss from that custom integration. The big one: when a model had a ref() to a node that lived in a different DAG, the integration automatically created an external task sensor pointing at that dbt node in the other DAG. I don't think Cosmos should necessarily implement that feature itself, but right now it would be extremely complicated to build it as a custom feature in an Airflow deployment.

Background: Issues with custom implementations today

Right now, the only real way to dive deep into Cosmos's internals is to use the node_converters API. It is typed as dict[DbtResourceType, Callable[..., Any]], but a more accurate typing would be dict[DbtResourceType, NodeConverterProtocol], with the following definition for NodeConverterProtocol:

class NodeConverterProtocol(Protocol):
    def __call__(
            self,
            dag: DAG,
            task_group: TaskGroup | None,
            node: DbtNode,
            execution_mode: ExecutionMode,
            task_args: dict[str, Any],
            test_behavior: TestBehavior,
            test_indirect_selection: TestIndirectSelection,
            on_warning_callback: Callable[..., Any] | None,
            **kwargs: Any
    ) -> BaseOperator | TaskGroup | None:
        ...

Here, the **kwargs don't do anything.

This API is hard to work with, to be honest. Imagine, for example, that you want to use a variable in the DbtNode's config to implement custom functionality. You'd have to do something like this:

from cosmos.airflow.graph import generate_task_or_group

def custom_callback(
    dag: DAG,
    task_group: TaskGroup | None,
    node: DbtNode,
    execution_mode: ExecutionMode,
    task_args: dict[str, Any],
    test_behavior: TestBehavior,
    test_indirect_selection: TestIndirectSelection,
    on_warning_callback: Callable[..., Any] | None,
    **kwargs: Any
):
    if "retries" in node.config:
        task_args["retries"] = node.config["retries"]
    return generate_task_or_group(
        dag=dag,
        task_group=task_group,
        node=node,
        execution_mode=execution_mode,
        task_args=task_args,
        test_behavior=test_behavior,
        test_indirect_selection=test_indirect_selection,
        on_warning_callback=on_warning_callback,
        **kwargs
    )

node_converters = {k: custom_callback for k in DbtResourceType.__members__.values()}

This is essentially a method override, but with a few oddities.

First, you're not actually overriding a method, you're wrapping a function.

Second, you need to register the custom node converter for each node type in the dict. Instead of the node converter callback dispatching logic by node type, the node converter callback itself is dispatched, even though the node type is already available in the args (via node.resource_type) and the only real difference across node types is the dbt operator class.

Third, the API is very large. There are 8 args in total, including one (on_warning_callback) that covers a very niche feature and is really just an operator arg for tests. on_warning_callback in particular is not well future-proofed: if more test-specific operator args are added, the API currently demands that each one be laid out as a new kwarg on the node converter protocol.

Fourth, many of the args are not (reasonably) task-specific; inside a single DbtTaskGroup or DbtDag, a lot of these variables could be accessed as shared state on the DbtToAirflowConverter object via self:

For the latter two, you can imagine a user who wants to dispatch custom execution mode or test behavior logic in a node-specific way, rather than using a shared implementation across every task. These niche situations for dispatching implementations can and should be supported, but they don't need to be supported by expanding the function signature; they can be implemented via method overrides.
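
For instance, under the subclassing approach proposed below, node-specific test behavior could be handled with a small method override instead of a wider signature. This is only a sketch against the proposed convert_dbt_node_to_airflow() hook, and the "no_tests" tag is a hypothetical convention:

class CustomDbtTaskGroup(DbtTaskGroup):

    def convert_dbt_node_to_airflow(
            self,
            node_id: str,
            dbt_node: DbtNode,
            operator_args: dict[str, Any],
    ) -> TaskGroup | BaseOperator | None:
        # Hypothetical convention: models tagged "no_tests" become a single run
        # operator instead of a (run + test) task group.
        if "no_tests" in dbt_node.tags:
            return self.create_dbt_operator(node_id, dbt_node, operator_args)
        return super().convert_dbt_node_to_airflow(node_id, dbt_node, operator_args)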

After more dissecting of the Cosmos API, there are other places where things don't feel quite right. For example, DbtGraph().load() takes two args: method and execution_mode. What doesn't make sense is that these args are already available in the state of the DbtGraph object, since the DbtGraph is initialized with the render_config and execution_config. I think it would make more sense for the load method's args to be optional kwargs that can override the config, rather than having the caller duplicate the config. I'm slightly picking on DbtGraph().load() because (A) this is a common theme and (B) using kwargs to override config state, rather than merely duplicating it, is a good pattern for supporting subclassing. In my Proposed API, I do this a couple of times: a method is called, by default, without any args/kwargs, but the implementation supports them. The idea is that, by default, all args originate from config state, while a caller still has an interface for overriding the config without actually mutating the config's state.
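
To make the pattern concrete from the caller's side, here is a small sketch against the proposed load() signature further down (not the current API):

graph = DbtGraph(project_config=project_config, render_config=render_config)

# By default, every arg originates from config state:
graph.load()  # uses render_config.load_method

# A caller can override the config for a single call without mutating it;
# render_config.load_method is left untouched.
graph.load(method=LoadMode.DBT_MANIFEST, reload=True)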

Proposed API

I think that a more composable interface should be accessible via subclassing of the DbtToAirflowConverter.

Here is a rough draft proposal. I've excluded a lot of the code necessary to make this work; I've only focused on the important bits, and added in-line comments about the philosophies behind these changes.

This code is very far from done, and I've skipped over a lot of details of the implementation. I put most of my effort into defining the topmost layer of the API, defined the relevant methods for the middle layer, and leave a few things beyond that up to imagination.

A lot of this will be controversial, I am aware. This is not a final proposal. I hope that this can be discussed and iterated on with others!

from __future__ import annotations
from typing import Any, Callable, Type
# from airflow.models import DAG
# from airflow.utils.task_group import TaskGroupContext
# from airflow.models.dag import DagContext
from airflow.models import BaseOperator
from cosmos.operators.base import AbstractDbtBaseOperator
from airflow.utils.task_group import TaskGroup

from cosmos.config import ProjectConfig, ProfileConfig, ExecutionConfig, RenderConfig
from airflow.utils.log.logging_mixin import LoggingMixin
from cosmos.dbt.graph import (
    # DbtGraph,
    DbtNode,
    LoadMode,
)

class DbtGraph:

    load_method_mapping: dict[LoadMode, Callable[[], None]] = {}

    def __init__(
            self,
            project_config: ProjectConfig | None = None,
            profile_config: ProfileConfig | None = None,
            execution_config: ExecutionConfig | None = None,
            render_config: RenderConfig | None = None,
            **kwargs
    ):
        super().__init__()
        # 'project' is inconsistent naming.
        # Elsewhere in the code, it is `project_config`.
        # The variable should be the same everywhere.
        if "project" in kwargs:
            if project_config is not None:
                raise TypeError("Cannot specify both 'project' and 'project_config'."
                                " Please use the project_config arg.")
            import warnings
            warnings.warn(
                "`project` is deprecated, and has been renamed to `project_config`."
                " Please use `project_config` instead.",
                DeprecationWarning,
                stacklevel=2
            )
        self.project_config = project_config

        if project_config is None:
            raise TypeError("'project_config' must be specified.")

        # It's not clear why the `ProfileConfig` requires a profile name
        # and target name. These could just use arbitrary defaults,
        # since Cosmos constructs the profile.
        self.profile_config = profile_config or ProfileConfig()
        self.execution_config = execution_config or ExecutionConfig()
        self.render_config = render_config or RenderConfig()

        # Before, load() both dispatched the load method _and_ was the implementation
        # of the automatic load method.
        # The automatic method should be its own method, distinct from load(), which
        # frees up load() to _only_ be a dispatcher of the load method. This achieves
        # separation of concerns.
        default_load_method_mapping = {
            LoadMode.AUTOMATIC: self.load_with_automatic_method,
            LoadMode.CUSTOM: self.load_via_custom_parser,
            LoadMode.DBT_LS: self.load_via_dbt_ls,
            LoadMode.DBT_LS_FILE: self.load_via_dbt_ls_file,
            LoadMode.DBT_MANIFEST: self.load_from_dbt_manifest,
        }
        # It looks like this in case load methods are set at the class level,
        # i.e. `default_load_method_mapping` sets defaults, but does not override
        # anything a user might set at the class level.
        default_load_method_mapping.update(self.load_method_mapping)
        self.load_method_mapping = default_load_method_mapping

        self.is_loaded: bool = False

        self.nodes: dict[str, DbtNode] = {}
        self.filtered_nodes: dict[str, DbtNode] = {}

    # Handle deprecations responsibly

    @property
    def project(self) -> ProjectConfig:
        import warnings
        warnings.warn(
            "`project` is deprecated, and has been renamed to `project_config`."
            " Please use `project_config` instead.",
            DeprecationWarning,
            stacklevel=2
        )
        return self.project_config

    @project.setter
    def project(self, v: ProjectConfig) -> None:
        import warnings
        warnings.warn(
            "`project` is deprecated, and has been renamed to `project_config`."
            " Please use `project_config` instead.",
            DeprecationWarning,
            stacklevel=2
        )
        self.project_config = v

    def load(
            self,
            method: LoadMode | None = None,
            reload: bool = False
    ) -> dict[str, DbtNode]:
        if not reload and self.is_loaded:
            return self.filtered_nodes
        # Load only dispatches;
        # does not implicitly have the automatic load mode inside.
        if method is None:
            method = self.render_config.load_method
        callback = self.load_method_mapping[method]
        # Before the load() method only mutated DbtGraph's state.
        # It's possible that DbtGraph should not rely on mutated state.
        return callback()

    def load_with_automatic_method(self) -> None: ...
    def load_via_custom_parser(self) -> None: ...
    def load_via_dbt_ls(self) -> None: ...
    def load_via_dbt_ls_file(self) -> None: ...
    def load_from_dbt_manifest(self) -> None: ...

# Right now Cosmos uses logger = logging.getLogger(),
# but objects in Airflow tend to use the LoggingMixin.
class DbtToAirflowConverter(LoggingMixin):

    # Right now the DbtGraph is purely internal.
    # This makes it much easier to implement custom DbtGraph parsing behavior.

    dbt_graph_class: Type[DbtGraph] = DbtGraph

    # Note that TaskGroup and DAG were removed from the args
    # for __init__().
    # This is because these are available via the following code:
    #
    # dag: DAG | None
    # task_group: TaskGroup | None
    # if isinstance(self, DAG):
    #     dag = self
    #     task_group = TaskGroupContext.get_current_task_group(dag)
    # elif isinstance(self, TaskGroup):
    #     dag = self.dag or DagContext.get_current_dag()
    #     task_group = self
    # else:
    #     dag = DagContext.get_current_dag()
    #     task_group = TaskGroupContext.get_current_task_group(dag)
    #
    # The situations where this doesn't work are incredibly niche, and
    # more importantly for our purposes, such situations aren't even
    # practically usable or supported in the current API as it stands.

    def __init__(
            self,
            project_config: ProjectConfig,
            profile_config: ProfileConfig | None = None,
            execution_config: ExecutionConfig | None = None,
            render_config: RenderConfig | None = None,
            operator_args: dict[str, Any] | None = None,
            eagerly_build_dbt_graph: bool = True,
            eagerly_build_airflow_graph: bool = True,
            **kwargs
    ):
        super().__init__()
        # The "configs" should be assigned to the state of the DAG,
        # so that they are accessible via other method calls without
        # requiring them to be assigned

        self.project_config = project_config
        # It's not clear why the `ProfileConfig` requires a profile name
        # and target name. These could just use arbitrary defaults,
        # since Cosmos constructs the profile.
        self.profile_config = profile_config or ProfileConfig()
        self.execution_config = execution_config or ExecutionConfig()
        self.render_config = render_config or RenderConfig()
        self.operator_args = operator_args

        # `on_warning_callback` feels out of place.
        # It is a single kwarg devoted to a very niche feature.
        #
        # This isn't necessarily bad on its own, but the Cosmos API is very
        # compressed into "config" objects, as opposed to being "flattened,"
        # so this pattern of passing this kwarg is not consistent with the
        # rest of the API.
        #
        # I feel like it should be deprecated and moved into `execution_config`.
        if "on_warning_callback" in kwargs:
            import warnings
            warnings.warn(
                "`on_warning_callback` is deprecated."
                " Please use execution_config.on_warning_callback instead.",
                DeprecationWarning,
                stacklevel=2
            )
        self._on_warning_callback = kwargs.pop("on_warning_callback", None)

        # Building the Airflow graph should not require passing any variables,
        # since all the state it relies on should be inside `self`.
        #
        # Also, this is designed to be overridden! Logic relating to
        # initialization of the class and building of the graph should be
        # properly separated. This should not just be seen as a continuation
        # of the __init__ call with an arbitrary boundary. The boundary is
        # between initializing non-graph state and producing the graph as a
        # side effect.
        #
        # Lastly, building and loading should be potentially deferrable.
        # This proposal allows users to do things like this:
        #
        # with TaskGroup("my_task_group") as tg:
        #    # Do things before the task_dict
        #    task_dict = converter.build_airflow_graph()
        #    # Do things with the task_dict

        self.dbt_graph: DbtGraph | None = None
        if eagerly_build_dbt_graph:
            self.dbt_graph = self.build_dbt_graph()
        if self.dbt_graph is not None and eagerly_build_airflow_graph:
            self.build_airflow_graph()

    @property
    def on_warning_callback(self) -> Callable | None:
        # Deprecated
        return self._on_warning_callback or self.execution_config.on_warning_callback

    def get_operator_args(self) -> dict[str, Any]:
        # This is a slight change to the API in a few ways.
        #
        # First, right now, config overrides operator args.
        # I feel like this should be the reverse, though:
        # if a user specifies an operator arg that overrides the config,
        # that sufficiently indicates the user's intent.

        # Second, execution and rendering are different contexts and may use
        # different env/vars. The most notable example is when execution uses
        # a templated variable, e.g. something from XComs, which could not be
        # used during rendering but only during execution. The ExecutionConfig
        # should have a way to override the project config. If this makes the API
        # too confusing, then call them "overrides", i.e. `dbt_env_overrides` or
        # `dbt_vars_overrides`. Or just call them `dbt_env` and `dbt_vars`. In any
        # case, the execution context needs a way to diverge from the render context.

        operator_args = self.operator_args.copy() if self.operator_args is not None else {}

        dbt_env = (
            self.project_config.dbt_env.copy()
            if self.project_config.dbt_env is not None
            else {}
        )
        dbt_env.update(self.execution_config.dbt_env_overrides or {})
        dbt_env.update(operator_args.pop("env", {}))

        dbt_vars = (
            self.project_config.dbt_vars.copy()
            if self.project_config.dbt_vars is not None
            else {}
        )
        dbt_vars.update(self.execution_config.dbt_vars_overrides or {})
        dbt_vars.update(operator_args.pop("vars", {}))

        kwargs = {
            "project_dir": self.execution_config.project_path,
            "partial_parse": self.project_config.partial_parse,
            "profile_config": self.profile_config,
            "emit_datasets": self.render_config.emit_datasets,
            "env": dbt_env,
            "vars": dbt_vars,
        }
        kwargs.update(operator_args)
        return kwargs

    def build_dbt_graph(self) -> DbtGraph:

        self.dbt_graph = self.dbt_graph_class(
            project_config=self.project_config,
            render_config=self.render_config,
            execution_config=self.execution_config,
            profile_config=self.profile_config,
        )
        self.dbt_graph.load()
        return self.dbt_graph

    def build_airflow_graph(
            self,
            nodes: dict[str, DbtNode] | None = None
    ) -> dict[str, TaskGroup | BaseOperator]:

        nodes = nodes or self.dbt_graph.load()
        tasks_map: dict[str, TaskGroup | BaseOperator] = {}

        for node_id, dbt_node in nodes.items():
            task_or_group = self.convert_dbt_node_to_airflow(
                node_id=node_id,
                dbt_node=dbt_node,
                operator_args=self.get_operator_args()
            )
            if task_or_group is not None:
                self.log.debug(f"Conversion of <{dbt_node.unique_id}> was successful!")
                tasks_map[node_id] = task_or_group

        for node_id, dbt_node in nodes.items():
            # Nodes that did not convert into a task or group are skipped here.
            if node_id not in tasks_map:
                continue
            self.create_task_dependencies(
                node_id=node_id,
                dbt_node=dbt_node,
                airflow_node=tasks_map[node_id],
                tasks_map=tasks_map
            )

        # This is where things like TestBehavior.AFTER_ALL get resolved.
        # It's exposed as a method so that it can be overridden.
        tasks_map = self.build_airflow_graph_post_hook(tasks_map=tasks_map)

        return tasks_map

    def convert_dbt_node_to_airflow(
            self,
            node_id: str,
            dbt_node: DbtNode,
            operator_args: dict[str, Any],
    ) -> TaskGroup | BaseOperator | None:
        ...

    def create_dbt_operator(
            self,
            node_id: str,
            dbt_node: DbtNode,
            operator_args: dict[str, Any],
    ) -> AbstractDbtBaseOperator:
        ...

    def build_airflow_graph_post_hook(
            self,
            tasks_map: dict[str, TaskGroup | BaseOperator]
    ) -> dict[str, TaskGroup | BaseOperator]:
        ...

    def create_task_dependencies(
            self,
            node_id: str,
            dbt_node: DbtNode,
            airflow_node: TaskGroup | BaseOperator,
            tasks_map: dict[str, TaskGroup | BaseOperator]
    ) -> None:
        ...

What this enables

I want to show just a small handful of the many cool things that all of these API changes enable for end users!

Custom business logic

Let's say a user has 1,000 dbt models, and 3 of them (foo, bar, and baz) need special operator args; for the sake of argument, let's say it's just to set retries to 5. The user could then do that with this:

class CustomDbtTaskGroup(DbtTaskGroup):

    def create_dbt_operator(
            self,
            node_id: str,
            dbt_node: DbtNode,
            operator_args: dict[str, Any],
    ) -> AbstractDbtBaseOperator:
        if node_id in {"foo", "bar", "baz"}:
            operator_args["retries"] = 5
        return super().create_dbt_operator(node_id, dbt_node, operator_args)

retries is a little contrived, but you can imagine something less contrived, like a custom profile_config for writing to a database with different credentials than the default (e.g. a Snowflake deployment with a complex IAM structure), or using a custom pool to control concurrency. The profile_config example in particular is noteworthy because there is a proposal right now to allow custom operator args to be set via the dbt YAML files, but that approach would not support un-serializable objects, so you wouldn't be able to set a custom profile_config through it.
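
For instance, the pool case could look like this, again as a sketch against the proposed create_dbt_operator() hook (the tag and pool names are made up for illustration):

class PooledDbtTaskGroup(DbtTaskGroup):

    def create_dbt_operator(
            self,
            node_id: str,
            dbt_node: DbtNode,
            operator_args: dict[str, Any],
    ) -> AbstractDbtBaseOperator:
        # Route expensive models to a dedicated Airflow pool to cap concurrency.
        if "expensive" in dbt_node.tags:
            operator_args["pool"] = "dbt_heavy_queries"
        return super().create_dbt_operator(node_id, dbt_node, operator_args)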

External Task Sensors and External DAG knowledge

This is my own white whale. What I really want to do is implement a system where I can split my single dbt project up into multiple DAGs running on multiple schedules, but these DAGs are all aware of tasks in other DAGs.

Say, for example, a user has a custom tagging system that looks like this:

models:
  - name: foo
    tags:
      - airflow_dag_and_task_id:my_airflow_dag.foo_run

What they want is that, whenever foo is ref()'d, the dependency points to an ExternalTaskSensor targeting my_airflow_dag.foo_run.

This can be implemented in just a few lines of code with the following custom code, using the above API changes:

class CustomDbtTaskGroup(DbtTaskGroup):

    def create_task_dependencies(
            self,
            node_id: str,
            dbt_node: DbtNode,
            airflow_node: TaskGroup | BaseOperator,
            tasks_map: dict[str, TaskGroup | BaseOperator]
    ) -> None:
        for parent_node_id in dbt_node.depends_on:
            if parent_node_id not in tasks_map and parent_node_id in self.dbt_graph.nodes:
                parent_node = self.dbt_graph.nodes[parent_node_id]
                if parent_node.resource_type not in {DbtResourceType.SOURCE, DbtResourceType.TEST}:
                    sensor = self.create_external_task_sensor(parent_node_id, parent_node)
                    if sensor is not None:
                        tasks_map[parent_node_id] = sensor
        super().create_task_dependencies(node_id, dbt_node, airflow_node, tasks_map)

    def create_external_task_sensor(
            self,
            node_id: str,
            dbt_node: DbtNode
    ) -> ExternalTaskSensor | None:
        for tag in dbt_node.tags:
            if tag.startswith("airflow_dag_and_task_id:"):
                external_dag_id, external_task_id = tag.replace("airflow_dag_and_task_id:", "").split(".")
                break
        else:
            return None
        return ExternalTaskSensor(
            task_id=f"{node_id}_sensor",
            external_dag_id=external_dag_id,
            external_task_id=external_task_id
        )

Now the user has a super cool custom feature for their own Cosmos deployment!

Xcoms in dbt vars / env with LoadMode.DBT_LS

One of my changes is, essentially, to walk back the following deprecation enforced in validate_initial_user_config():

# Cosmos 2.0 will remove the ability to pass in operator_args with 'env' and 'vars' in place of ProjectConfig.env_vars and
# ProjectConfig.dbt_vars.

I think it's correct to walk this back: there is a practical situation where you need to distinguish the env and vars used in different command invocation contexts, namely Airflow's own Jinja template rendering.

Right now, using XComs with --vars and using LoadMode.DBT_LS do not always mix. The reason is simple: {{ ti.xcom_pull(task_ids="foo") }} gets rendered appropriately during task execution, but it is a string literal that never touches a Jinja environment when the DAG is rendered. More concretely:

A user who is using LoadMode.DBT_LS will likely want to set a default for the logical date when loading, one that parses as a valid ISO date, e.g. "1970-01-01" or date.today(). For example, the dbt project may do something like {% set year, month, day = var("logical_date", "1970-01-01").split("-") %}, which would work in execution but would fail in rendering/loading.
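
Concretely, under the proposed get_operator_args() behavior, the render-time and execution-time values could be split like this. Note that dbt_vars_overrides on ExecutionConfig is part of this proposal (not the current API), and the project path and upstream task id are made up:

tg = DbtTaskGroup(
    project_config=ProjectConfig(
        dbt_project_path="/usr/local/airflow/dbt/my_project",
        # Used when rendering/loading with LoadMode.DBT_LS: a value that parses as a date.
        dbt_vars={"logical_date": "1970-01-01"},
    ),
    execution_config=ExecutionConfig(
        # Proposed field: only applied to the operators, where Airflow renders the
        # Jinja template before dbt ever sees the value.
        dbt_vars_overrides={"logical_date": "{{ ti.xcom_pull(task_ids='get_logical_date') }}"},
    ),
)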

Other niche things

These are more niche, but still cool that they're supported!

The most useful aspect of these "niche things" is that these features open up development of Cosmos itself. Right now, testing an experimental feature change in a production instance of Cosmos is annoying, as there is no easy way to override chunks of the code. Unless you want to put a ton of effort into it, you are mostly just beholden to whatever the features are in the latest release.

So although there are some niche things you can do with the API changes that don't provide direct support to the vast majority of users, they allow people who are developing Cosmos itself to do so much more easily.

Here is one example that uses custom DbtGraph parsing. When I say custom, I mean implementing your own parsing behavior, not using LoadMode.CUSTOM. Say, for example, you want to experiment with whether adding lru_cache to DbtGraph.load() improves performance. You could write the following:


from functools import lru_cache

class CustomDbtGraph(DbtGraph):

    @lru_cache
    def load(self, *args, **kwargs):
        return super().load(*args, **kwargs)

class CustomDbtTaskGroup(DbtTaskGroup):
    dbt_graph_class = CustomDbtGraph

And now you have code you can test immediately in your production instance. Experimenting with a feature like this in an existing Airflow deployment is next to impossible in the current iteration of Cosmos. But with the API changes I've proposed, there is now an avenue for testing something experimental like this in a real production instance.

Endless possibilities

I actually don't know all of the features that users want. It's hard to come up with things. I think that's the point, though! Cosmos does not need to anticipate every single user need; it just needs to anticipate what users need access to in order to implement their own features.

Things still left

dosubot[bot] commented 6 months ago

Hey @dwreeves! Great to see you back with more insightful feedback for Cosmos. How have you been?

Thank you for your detailed proposal and the effort you've put into identifying areas for improvement in the Cosmos API. Your suggestions for making the API more composable and accessible via subclassing of the DbtToAirflowConverter are insightful and could potentially enhance the flexibility and usability of Cosmos.

Let's address your points one by one:

  1. Enhanced Configuration Handling: Your suggestion to make args optional kwargs that can override the config is a good one. This would indeed allow all args to originate from config state by default, but also provide an easy interface for overriding the config without actually mutating the config's state. This could be achieved by refactoring the DbtToAirflowConverter constructor and the DbtGraph().load() method to support optional kwargs.

  2. Modular Graph Loading and Task Argument Preparation: Extracting the logic for loading the DBT graph and preparing task arguments into their own methods would indeed improve readability and allow subclasses to modify these processes without overriding the entire constructor. This could be achieved by introducing new methods in the DbtToAirflowConverter class.

  3. Flexible Airflow Graph Building: Moving the call to build_airflow_graph into a separate method would allow subclasses to override or extend the graph building logic. This could be achieved by introducing a new method in the DbtToAirflowConverter class.

  4. Decoupling Validation Logic: Extracting validation into separate, overridable methods would indeed allow subclasses to introduce new validations or modify existing ones more easily. This could be achieved by introducing new methods in the DbtToAirflowConverter class.

However, implementing these changes would require a significant refactoring of the DbtToAirflowConverter class and potentially other parts of the Cosmos codebase. It's also important to consider the potential impact on existing users of Cosmos, as these changes could break existing code that relies on the current API.

As a maintainer, I can't make these changes directly, but I encourage you to submit a pull request with your proposed changes. This would allow the Cosmos community to review your proposal in detail and consider its potential benefits and drawbacks.

In the meantime, you might consider creating a subclass of DbtToAirflowConverter in your own code that implements your proposed changes. This would allow you to take advantage of the increased flexibility and composability without waiting for changes to the Cosmos API.

I hope this helps, and I look forward to seeing your pull request!

tatiana commented 2 months ago

Hi, @dwreeves. I believe there are many great ideas in this proposal. What do you think about breaking them into smaller changes so we can spread them across the next releases of Cosmos?

We can have a call to talk about this if it would help you. I'd love to see progress in this direction for 1.6, but as the task is described now, we wouldn't have bandwidth.