aiidateam / aiida-workgraph

Efficiently design and manage flexible workflows with AiiDA, featuring an interactive GUI, checkpoints, provenance tracking, and remote execution capabilities.
https://aiida-workgraph.readthedocs.io/en/latest/
MIT License
9 stars 5 forks source link

`graph_builder` with codes raises "pickling of AiiDA ORM instances is not supported" error #250

Open agoscinski opened 3 weeks ago

agoscinski commented 3 weeks ago
# Create and store an InstalledCode on the 'localhost' computer that runs `cat`.
# This node lives at module scope. The decorated `cat_wg` below references it as a
# global, so when the decorator serializes that function with (cloud)pickle it also
# tries to pickle this ORM node. AiiDA refuses that with
# InvalidOperation: "pickling of AiiDA ORM instances is not supported."
cat_code = orm.InstalledCode(
    computer=orm.load_computer('localhost'),
    filepath_executable='cat',

).store()

# Graph builder whose `cat_task` output is taken (via the `from` key) from the
# `stdout` output of the inner task named "cat_task".
# NOTE: this version fails at decoration time. The decorator calls
# serialize_function -> pickle.dumps(func), and the function body captures the
# module-level ORM node `cat_code`. Either create/load the code inside the function
# or pass it in as an argument.
@task.graph_builder(outputs = [{"name": "cat_task", "from": "cat_task.stdout"}])
def cat_wg():
    """Build a WorkGraph with a single ShellJob that runs the global `cat_code`.

    Returns:
        The WorkGraph containing the "cat_task" ShellJob.
    """
    # Create a WorkGraph
    wg = WorkGraph()
    # `command=cat_code` refers to a module-level ORM node; this reference is what
    # makes the decorated function unpicklable.
    wg.add_task("ShellJob",  command=cat_code, name="cat_task")
    # don't forget to return the `wg`
    return wg

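raises the following error (the traceback was captured from an earlier variant of the cell, named `add_multiply` and using `cat_task.result`, but the failure is the same):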

/var/folders/lh/d5j2y3816xg0qffzv9bqlx2c0000gn/T/ipykernel_6492/2223892296.py in <cell line: 7>()
      6 
      7 @task.graph_builder(outputs = [{"name": "cat_task", "from": "cat_task.result"}])
----> 8 def add_multiply():
      9     # Create a WorkGraph
     10     wg = WorkGraph()

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/aiida_workgraph/decorator.py in decorator(func)
    589             #
    590             task_type = "graph_builder"
--> 591             tdata = generate_tdata(
    592                 func,
    593                 identifier,

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/aiida_workgraph/decorator.py in generate_tdata(func, identifier, inputs, outputs, properties, catalog, task_type, additional_data)
    492         "inputs": _inputs,
    493         "outputs": task_outputs,
--> 494         "executor": serialize_function(func),
    495         "catalog": catalog,
    496     }

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/aiida_workgraph/utils/__init__.py in serialize_function(func)
    606         import_statements = ""
    607     return {
--> 608         "executor": pickle.dumps(func),
    609         "type": "function",
    610         "is_pickle": True,

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/cloudpickle/cloudpickle.py in dumps(obj, protocol, buffer_callback)
   1477     with io.BytesIO() as file:
   1478         cp = Pickler(file, protocol=protocol, buffer_callback=buffer_callback)
-> 1479         cp.dump(obj)
   1480         return file.getvalue()
   1481 

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/cloudpickle/cloudpickle.py in dump(self, obj)
   1243     def dump(self, obj):
   1244         try:
-> 1245             return super().dump(obj)
   1246         except RuntimeError as e:
   1247             if len(e.args) > 0 and "recursion" in e.args[0]:

~/miniconda3/envs/euroscipy-aiida-demo/lib/python3.10/site-packages/aiida/orm/entities.py in __getstate__(self)
    239     def __getstate__(self):
    240         """Prevent an ORM entity instance from being pickled."""
--> 241         raise InvalidOperation('pickling of AiiDA ORM instances is not supported.')
    242 
    243     @super_check

InvalidOperation: pickling of AiiDA ORM instances is not supported.

while using

    wg.add_task("ShellJob",  command='cat', name="cat_task")

works fine

superstar54 commented 3 weeks ago

Indeed, AiiDA ORM instances cannot be pickled. You need to move the creation of the code into the graph builder:


# Graph builder whose `cat_task` output is taken (via the `from` key) from the
# `stdout` output of the inner task named "cat_task".
# The InstalledCode is created inside the body. The function the decorator pickles
# therefore references no ORM instance, and the node only comes into existence when
# the builder actually runs.
@task.graph_builder(outputs = [{"name": "cat_task", "from": "cat_task.stdout"}])
def cat_wg():
    """Create a `cat` code and build a WorkGraph with one ShellJob that runs it.

    Returns:
        The WorkGraph containing the "cat_task" ShellJob.
    """
    # create and store the code here, at run time rather than at decoration time
    # NOTE(review): every execution of this builder stores a *new* InstalledCode
    # node. If an existing code should be reused, consider loading it instead
    # (e.g. with orm.load_code); confirm which behavior is intended.
    cat_code = orm.InstalledCode(
        computer=orm.load_computer('localhost'),
        filepath_executable='cat',
    ).store()
    # Create a WorkGraph
    wg = WorkGraph()
    wg.add_task("ShellJob",  command=cat_code, name="cat_task")
    # don't forget to return the `wg`
    return wg

The above code works because `cat_code` is only created when the daemon runs the graph_builder task, so the decorator never has to pickle it. `cat_code` is then used as an input of the WorkGraph, which is serialized to JSON before the WorkGraph is submitted to the engine.

We should point this out explicitly in the documentation.

Edit: alternatively, pass the code as an argument to the graph_builder:

# Create and store the code at module level. It is meant to be handed to `cat_wg`
# as its `code` argument rather than being referenced from inside the function
# body, so the decorator never has to pickle this ORM node.
cat_code = orm.InstalledCode(
        computer=orm.load_computer('localhost'),
        filepath_executable='cat',
    ).store()

# Graph builder whose `cat_task` output is taken (via the `from` key) from the
# `stdout` output of the inner task named "cat_task".
@task.graph_builder(outputs = [{"name": "cat_task", "from": "cat_task.stdout"}])
def cat_wg(code):
    """Build a WorkGraph with a single ShellJob that runs `code`.

    The code arrives as an argument instead of being captured from module scope,
    so the function serialized by the decorator holds no reference to an AiiDA
    ORM instance.

    Args:
        code: The code to execute, passed through as the ShellJob `command` input.

    Returns:
        The WorkGraph containing the "cat_task" ShellJob.
    """
    # Create a WorkGraph
    wg = WorkGraph()
    wg.add_task("ShellJob",  command=code, name="cat_task")
    # don't forget to return the `wg`
    return wg