zincware / ZnTrack

Create, visualize, run & benchmark DVC pipelines in Python & Jupyter notebooks.
https://zntrack.readthedocs.io
Apache License 2.0
47 stars 5 forks source link

Improve Graph Writing #785

Open PythonFZ opened 5 months ago

PythonFZ commented 5 months ago

Introduce new base methods for all `zntrack.<field>` and `zntrack.<field>_path` fields. Each field can override them; by default they return an empty dict:

    def get_zntrack_data(self, instance: "Node") -> dict:
        """Return the entries this field contributes to the zntrack config file.

        This default contributes nothing. Field types that need to write to
        the zntrack config file are meant to override it.

        Args:
            instance: the Node this field is attached to (unused here).

        Returns:
            A freshly created, empty dict.
        """
        return dict()

    def get_dvc_data(self, instance: "Node") -> dict:
        """Return the entries this field contributes to the dvc config file.

        This default contributes nothing. Field types that need to write to
        the dvc config file are meant to override it.

        Args:
            instance: the Node this field is attached to (unused here).

        Returns:
            A freshly created, empty dict.
        """
        return dict()

    def get_params_data(self, instance: "Node") -> dict:
        """Return the entries this field contributes to the params file.

        This default contributes nothing. Field types that need to write to
        the params file are meant to override it.

        Args:
            instance: the Node this field is attached to (unused here).

        Returns:
            A freshly created, empty dict.
        """
        return dict()
PythonFZ commented 5 months ago

Example of what this could look like: the config files for the entire graph are written once at the end, rather than incrementally after each Node.

def node_to_dicts(node) -> tuple[dict, dict, dict]:
    """Collect the zntrack.json, dvc.yaml and params.yaml entries of one node.

    Every field descriptor of ``node`` adds its contribution through
    ``get_zntrack_data``, ``get_params_data`` and ``get_dvc_data``.

    Args:
        node: the Node instance to serialize.

    Returns:
        A 3-tuple ``({name: zntrack_entry}, {"stages": {name: dvc_stage}},
        {name: params_entry})`` where ``name`` is ``node.name``.
    """
    node_cls = node.__class__
    module = module_handler(node_cls)

    # Round-trip through ZnEncoder so the stored nwd is plain JSON data.
    zntrack_entry = {
        "nwd": json.loads(json.dumps(node.nwd, cls=znjson.ZnEncoder)),
    }
    params_entry = {}
    dvc_stage = {
        "cmd": f"zntrack run {module}.{node_cls.__name__} --name {node.name}",
    }

    descriptors = zninit.get_descriptors(Field, self=node)
    for descriptor in descriptors:
        zntrack_entry.update(descriptor.get_zntrack_data(node))
        params_entry.update(descriptor.get_params_data(node))
        dvc_stage.update(descriptor.get_dvc_data(node))

    name = node.name
    return (
        {name: zntrack_entry},
        {"stages": {name: dvc_stage}},
        {name: params_entry},
    )

def build(graph):
    """Write zntrack.json, dvc.yaml and params.yaml for all nodes of ``graph``.

    The entries of every node are collected first. Each config file is then
    written exactly once, instead of being rewritten after each node.

    Args:
        graph: graph whose ``graph.nodes[node]["value"]`` holds the Node
            instances (networkx-style node view assumed — TODO confirm).
    """
    zntrack_ = {}
    stages = {}
    params_ = {}
    for node in graph.nodes:
        node_obj = graph.nodes[node]["value"]
        node_zntrack, node_dvc, node_params = node_to_dicts(node_obj)
        zntrack_.update(node_zntrack)
        # node_to_dicts nests each stage under a "stages" key. Merge the inner
        # mappings; merging the outer dicts would keep only the last stage.
        stages.update(node_dvc["stages"])
        params_.update(node_params)
    # With an empty graph this writes empty configs instead of raising while
    # unpacking zip(*[]).
    dvc_ = {"stages": stages}

    with open("zntrack.json", "w", encoding="utf-8") as f:
        json.dump(
            zntrack_,
            f,
            indent=4,
            cls=znjson.ZnEncoder.from_converters(
                [CombinedConnectionsConverter, ConnectionConverter], add_default=True
            ),
        )

    with open("dvc.yaml", "w", encoding="utf-8") as f:
        yaml.dump(dvc_, f)

    with open("params.yaml", "w", encoding="utf-8") as f:
        yaml.dump(params_, f)

and then use it like this:

# Create the nodes inside the graph context. Presumably every Node built here
# is registered in ``graph`` (inferred from usage — TODO confirm).
with graph:
    a = AddOne(inputs=1, name="a")
    b = AddOne(inputs=2, name="b")
    c = AddTwo(one=a.outputs, two=b.outputs, name="c")
    # Append a chain of 200 AddTwo nodes; each one feeds the previous node's
    # outputs into both of its inputs.
    for idx in range(200):
        c = AddTwo(one=c.outputs, two=c.outputs, name=f"iter_{idx}")

# Write the config files for the whole graph in a single pass.
build(graph) # instead of graph.run()