MetOffice / dagrunner

⛔[EXPERIMENTAL] Directed acyclic graph (DAG) runner and tools
BSD 3-Clause "New" or "Revised" License

Skip successor nodes in an execution branch at run-time #6

Closed cpelley closed 1 month ago

cpelley commented 6 months ago

Issue


Raising an exception in one branch should skip the dependent nodes in that branch without affecting other branches. This can be done with a combination of custom exceptions and a custom delayed function.

I have demonstrated that this can indeed work, and it will be one of the mechanisms made available to the new framework tooling for controlling graph execution. Caveat: raising an exception from within a dask node (say, a new SkipBranch exception) and capturing it through the .compute call that executes the graph only works with the dummy-task (noop) approach, which ensures that a terminated branch returns nothing. Single-threaded scheduling will not support this either way, because its sequential execution prevents the mechanism from being useful.
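The intended behaviour (one node raises, its dependents are skipped, other branches carry on) can be sketched without dask. The following is only an illustration under stated assumptions: `run_graph` is a hypothetical hand-written runner, not dask or dagrunner API, and because it owns the exception handling itself, the dask scheduler caveat above does not apply to it.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+


class SkipBranch(Exception):
    """Raised by a node to skip itself and everything downstream of it."""


def run_graph(nodes, deps):
    """Run a DAG in dependency order.

    `nodes` maps a name to a callable and `deps` maps a name to the names of
    its upstream nodes. Hypothetical helper, for illustration only.
    """
    results, skipped = {}, set()
    for name in TopologicalSorter(deps).static_order():
        upstream = deps.get(name, ())
        # A node is skipped when any of its inputs was skipped.
        if any(dep in skipped for dep in upstream):
            skipped.add(name)
            continue
        try:
            results[name] = nodes[name](*(results[dep] for dep in upstream))
        except SkipBranch:
            skipped.add(name)
    return results, skipped


def divide(x, y):
    if y == 0:
        raise SkipBranch("Cannot divide by zero")
    return x / y


nodes = {
    "a": lambda: 10,
    "b": lambda: 0,
    "c": lambda: 5,
    "a_over_b": divide,  # raises SkipBranch
    "c_over_a": divide,  # independent branch, still runs
    "report": lambda value: f"a/b = {value}",  # depends on the skipped node
}
deps = {
    "a": (), "b": (), "c": (),
    "a_over_b": ("a", "b"),
    "c_over_a": ("c", "a"),
    "report": ("a_over_b",),
}

results, skipped = run_graph(nodes, deps)
print(results["c_over_a"])  # 0.5
print(sorted(skipped))      # ['a_over_b', 'report']
```

Tracking skipped nodes in a set, rather than letting the exception escape the runner, is what keeps the independent `c_over_a` branch running.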

Demo

import dask
from dask import delayed

class SkipBranch(Exception):
    pass

class SkipComputation(Exception):
    pass

def divide(x, y):
    if y == 0:
        raise SkipBranch("Cannot divide by zero")
    return x / y

def process_data(a, b):
    try:
        result = divide(a, b)
        return result
    except SkipBranch:
        raise  # Re-raise the exception to propagate it up
    except SkipComputation:
        raise SkipBranch("Exception in the current branch, skipping dependent nodes")

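# Marker returned in place of a result when a node is skipped
class Skipped(str):
    pass

# Create a custom delayed function that handles exceptions
def custom_delayed(func):
    def wrapper(*args, **kwargs):
        # If an upstream node was skipped, skip this node too rather than
        # passing the marker into func (which would raise a TypeError)
        for arg in (*args, *kwargs.values()):
            if isinstance(arg, Skipped):
                return arg
        try:
            return func(*args, **kwargs)
        except SkipBranch as e:
            return Skipped(f"Skipped branch: {e}")
        except SkipComputation as e:
            return Skipped(f"Skipped: {e}")

    return delayed(wrapper)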

# Create Dask computation graph
a = delayed(10)
b = delayed(0)
c = delayed(5)

try:
    result1 = custom_delayed(process_data)(a, b)
    result2 = custom_delayed(process_data)(c, a)
    result3 = custom_delayed(process_data)(b, c)
    result4 = custom_delayed(process_data)(result1, result2)
    result5 = custom_delayed(process_data)(result3, result4)

    # Visualize the Dask graph (needs the optional graphviz dependency)
    dask.visualize([result1, result2, result3, result4, result5])

    # Execute the Dask computation graph
    final_result = result5.compute()
    print("Result:", final_result)
except Exception as e:
    print("Caught exception:", e)
    # Handle the exception and recover if possible
    # Your recovery logic here

cpelley commented 1 month ago

This work was partly completed in the initial commit (special exception handling) and finalised in https://github.com/MetOffice/dagrunner/pull/50 (skip event, a scheduler-agnostic approach).
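To illustrate why a skip event can be scheduler-agnostic, here is a minimal sketch in which a node returns a sentinel value instead of raising. It is not the implementation from the pull request above; `SKIP`, `skippable` and `run` are hypothetical names chosen for this example. Because the signal travels as an ordinary return value, the same graph gives the same result whether nodes are called directly or through a thread pool.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import wraps


class _SkipEvent:
    """Hypothetical sentinel type; a node returns SKIP instead of raising."""

    def __repr__(self):
        return "SKIP"


SKIP = _SkipEvent()


def skippable(func):
    """Return SKIP without calling `func` if any input is SKIP."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        if any(arg is SKIP for arg in (*args, *kwargs.values())):
            return SKIP
        return func(*args, **kwargs)
    return wrapper


@skippable
def divide(x, y):
    return SKIP if y == 0 else x / y


@skippable
def add(x, y):
    return x + y


def run(submit):
    """Evaluate the same small graph through any `submit(func, *args)` callable."""
    skipped = submit(divide, 10, 0)       # this branch signals a skip
    healthy = submit(divide, 5, 10)       # independent branch
    downstream = submit(add, skipped, 1)  # depends on the skipped branch
    return healthy, downstream


# Sequential execution: call each node directly.
sequential = run(lambda func, *args: func(*args))

# Threaded execution: each node runs in a worker thread.
with ThreadPoolExecutor(max_workers=2) as pool:
    threaded = run(lambda func, *args: pool.submit(func, *args).result())

print(sequential)  # (0.5, SKIP)
print(threaded)    # (0.5, SKIP)
```

Checking inputs in a decorator keeps the node functions themselves unaware of skipping, at the cost of requiring every node in the graph to be wrapped.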