PaulSchweizer / flowpipe

Very simple flow-based programming framework.
MIT License

Iteration node #140

Open Onandon11 opened 4 years ago

Onandon11 commented 4 years ago

Is your feature request related to a problem? Please describe. In our library we have to call a specific subgraph quite frequently (for example 100+ times). Currently there are two options to achieve this:

  1. Make a unique node for every iteration. This goes against the architecture of flowpipe because it creates duplicate nodes. It also clutters the output of print(graph).
  2. Make the node capable of handling subinputs and create a connection for every iteration. This requires a lot of extra code, because every node that you might want to iterate over has to be capable of handling subinputs. It also clutters the connection schema in print(graph), which then draws a lot of lines.

Describe the solution you'd like. What I have in mind is that Node.evaluate checks whether any inputs should be iterated over. If so, the node calls its compute function once per item. The output of the node is then an iterable as well, and the next node should detect that in the same way. As a proof of principle, I subclassed FunctionNode and created the following (P.S.: this is far from perfect, it is only meant as an example).

import time

from flowpipe.node import FunctionNode


class IterationPlug(list):
    """Marker list: a node iterates over its items instead of receiving the list as a whole."""


class Node(FunctionNode):
    def __init__(self, *args, n_iterations=None, disable_iterations=False, **kwargs):
        super().__init__(*args, **kwargs)
        # At instantiation of the node you can specify the number of iterations.
        # If specified, a check is done that each input `IterationPlug` has that many values;
        # otherwise an AssertionError is raised.
        # `disable_iterations` can be set to treat an `IterationPlug` input as a plain value.
        self._n_iterations = n_iterations
        self._disable_iterations = disable_iterations

    def evaluate(self):
        """Compute this Node, log it and clean the input Plugs.

        Also push a stat report in the following form containing the Node,
        evaluation time and timestamp the computation started.
        """
        if self.omit:
            self.EVENTS['evaluation-omitted'].emit(self)
            return {}

        self.EVENTS['evaluation-started'].emit(self)

        inputs = {}
        for name, plug in self.inputs.items():
            inputs[name] = plug.value

        # Look if iteration is needed
        is_iteration_detected, iterables = False, {}
        # If iteration is disabled, just continue normally
        if not self._disable_iterations:
            for key, value in inputs.items():
                if isinstance(value, IterationPlug):  # Input value is an IterationPlug so iteration is needed
                    is_iteration_detected = True
                    iterables[key] = [{key: x} for x in value]  # Create a list of kwargs to iterate over later on.
                    if self._n_iterations is not None:
                        # Check the number of iterations if provided
                        assert len(value) == self._n_iterations, f"Length ({len(value)}) of iterable `{key}` is unequal to expected length {self._n_iterations}"
                elif isinstance(value, dict):  # Special case for subinputs
                    for subkey, subvalue in value.items():
                        if isinstance(subvalue, IterationPlug):  # Same as above but a layer deeper.
                            is_iteration_detected = True
                            if key not in iterables:
                                iterables[key] = [{key: {subkey: x}} for x in subvalue]
                            else:
                                iterables[key] = [{key: {**x[key], subkey: y}} for x, y in zip(iterables[key], subvalue)]
                            if self._n_iterations is not None:
                                assert len(subvalue) == self._n_iterations, f"Length ({len(subvalue)}) of iterable `{key}` is unequal to expected length {self._n_iterations}"

        # Compute and redirect the output to the output plugs
        start_time = time.time()
        if is_iteration_detected:  # If iteration is detected
            outputs = {}
            for arguments in zip(*iterables.values()):  # Loop over the list of kwargs in iterables
                for arg in arguments:  # Update the inputs for this iteration
                    inputs.update(arg)
                output = self.compute(**inputs) or dict()  # Compute the node
                for key, value in output.items():
                    if key not in outputs:
                        outputs[key] = IterationPlug()
                    # Put the output also in an IterationPlug so the next node also knows that it should iterate.
                    outputs[key].append(value)
        else:
            outputs = self.compute(**inputs) or dict()
        eval_time = time.time() - start_time

        self.stats = {
            'eval_time': eval_time,
            'start_time': start_time
        }

        # all_outputs = self.all_outputs()
        for name, value in outputs.items():
            if '.' in name:
                parent_plug, sub_plug = name.split('.')
                self.outputs[parent_plug][sub_plug].value = value
            else:
                self.outputs[name].value = value

        # Set the inputs clean
        for input_ in self.all_inputs().values():
            input_.is_dirty = False

        self.EVENTS['evaluation-finished'].emit(self)

        return outputs

    def node_repr(self):
        pretty = super().node_repr()
        if not self._disable_iterations:
            pretty = pretty[:-11] + f'n_iter:{self._n_iterations or "??":<2}-+'
        return pretty

def GEAMSNode(*args, **kwargs):
    """Wrap the given function into a Node."""
    cls = kwargs.pop("cls", Node)

    def node(func):
        return cls(func, *args, **kwargs)
    return node
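
The sub-input branch of evaluate is the hardest part to read, so here is a minimal standalone sketch of just that merging step. It does not use flowpipe; the helper name merge_subinputs and the sample input "points" are made up for illustration.

```python
class IterationPlug(list):
    """Marker type: a list whose items are processed one per iteration."""


def merge_subinputs(key, subinputs):
    """Build one kwargs dict per iteration from a dict of sub-inputs.

    Hypothetical helper that mirrors the sub-input branch of evaluate();
    it is not part of flowpipe.
    """
    per_iteration = None
    for subkey, subvalue in subinputs.items():
        if not isinstance(subvalue, IterationPlug):
            continue  # plain sub-inputs are ignored, as in the proof of concept
        if per_iteration is None:
            # First iterated sub-input: start one dict per item.
            per_iteration = [{key: {subkey: x}} for x in subvalue]
        else:
            # Further iterated sub-inputs: merge item i into dict i.
            per_iteration = [
                {key: {**existing[key], subkey: y}}
                for existing, y in zip(per_iteration, subvalue)
            ]
    return per_iteration or []


merged = merge_subinputs(
    "points",
    {"x": IterationPlug([1, 2]), "y": IterationPlug([10, 20])},
)
print(merged)
# [{'points': {'x': 1, 'y': 10}}, {'points': {'x': 2, 'y': 20}}]
```

Note that sub-inputs which are not IterationPlugs are not copied into the per-iteration dicts. In the proof of concept, inputs.update(arg) then replaces the whole dict for that input, so those plain sub-inputs would not reach compute. A cleaned-up version would probably need to merge them back in.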

Now you can write the following and it iterates without making more connections or nodes.

import random
from flowpipe import Graph

@GEAMSNode(outputs=['out'])
def foo():
    n_values = random.randint(1, 10)
    return {'out': IterationPlug([random.random() for _ in range(n_values)])}

@GEAMSNode()
def show_value(value):
    print(value)

graph = Graph(name='Iteration concept')
node_foo = foo(graph=graph, name='Creating random number of values')
node_show = show_value(graph=graph, name='Showing the values')

node_foo.outputs['out'] >> node_show.inputs['value']

print(graph)
graph.evaluate()

# +---------------------------------------+          +-------------------------+
# |   Creating random number of values    |          |   Showing the values    |
# |---------------------------------------|          |-------------------------|
# |                                   out o--------->o value<>                 |
# +-----------------------------n_iter:??-+          +---------------n_iter:??-+
# 2020-08-19 09:30:00,047 <      INFO> Evaluating Graph "Iteration concept"
# 0.4673213321721741
# 0.9965826590356521
# 0.40193311849899194
# 0.02730160786984215
# 0.552214249257213
# 0.1996391310232184
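
For readers who want to try the idea without installing flowpipe, the core of it can be sketched as a plain function. This is only an approximation of what evaluate does above: the helper evaluate_with_iteration and the functions scale and describe are hypothetical names, and sub-inputs, events, and stats are left out.

```python
class IterationPlug(list):
    """Marker type: a list whose items are processed one per iteration."""


def evaluate_with_iteration(compute, inputs):
    """Call compute once, or once per item if any input is an IterationPlug.

    Hypothetical helper for illustration; it is not part of flowpipe.
    """
    # For every iterated input, build a list of single-entry kwargs dicts.
    iterables = {
        key: [{key: item} for item in value]
        for key, value in inputs.items()
        if isinstance(value, IterationPlug)
    }
    if not iterables:
        return compute(**inputs) or {}

    outputs = {}
    call_kwargs = dict(inputs)
    # zip() pairs the i-th item of every iterated input and stops at the
    # shortest one, so unequal lengths are truncated silently.
    for arguments in zip(*iterables.values()):
        for arg in arguments:
            call_kwargs.update(arg)
        result = compute(**call_kwargs) or {}
        for key, value in result.items():
            # Wrap results in an IterationPlug so the next node iterates too.
            outputs.setdefault(key, IterationPlug()).append(value)
    return outputs


def scale(value, factor):
    return {"out": value * factor}


def describe(value):
    return {"text": f"value={value}"}


scaled = evaluate_with_iteration(
    scale, {"value": IterationPlug([1, 2, 3]), "factor": 10}
)
described = evaluate_with_iteration(describe, {"value": scaled["out"]})
print(scaled["out"])      # [10, 20, 30]
print(described["text"])  # ['value=10', 'value=20', 'value=30']
```

Because zip() stops at the shortest input, two IterationPlug inputs of different lengths are truncated without any error unless n_iterations is set on the node, which is why the length assertion in the proof of concept is useful.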

Please let me know what you guys think of this feature 😄 !!

neuneck commented 4 years ago

Interesting suggestion! I think it somewhat relates to https://github.com/PaulSchweizer/flowpipe/issues/139 - both issues could be solved easily by transitioning from a batch-processing approach towards a streaming architecture. Still, the option of having nodes iterate over inputs could already allow flowpipe to fulfill many purposes that, for now, it couldn't.

PaulSchweizer commented 4 years ago

This looks very interesting, thanks for the elaborate example @Onandon11! I'd be happy to see this implemented. @neuneck, switching to a streaming architecture is a bit more involved, but as you say, this addition might already fulfil a lot of purposes for now.

Onandon11 commented 4 years ago

Great :D Shall I make a PR and try to clean up the example to implement it?

PaulSchweizer commented 4 years ago

Yes, that would be great, thanks @Onandon11