tf-encrypted / moose

Secure distributed dataflow framework for encrypted machine learning and data processing
Apache License 2.0

Make AbstractComputation nest-able #1114

Open jvmncs opened 2 years ago

jvmncs commented 2 years ago

Functions written with the eDSL should be callable from within other computations, regardless of whether they've been wrapped with the pm.computation decorator. For example,

```python
import pymoose as pm

alice = pm.host_placement(name="alice")

@pm.computation
def plus1(x: pm.Argument(alice, dtype=pm.float64)):
    with alice:
        one = pm.constant(1, dtype=pm.float64)
        return pm.add(x, one)

@pm.computation
def alice_add():
    with alice:
        x = pm.constant(3, dtype=pm.float64)
        x_plus_one = plus1(x)
    return x_plus_one

if __name__ == "__main__":
    [...]  # runtime construction elided
    runtime.set_default()
    val = alice_add()  # <-- will fail during tracing
```

When alice_add is called, the current behavior is as follows:

One workaround is for the user to drop the pm.computation decorator from plus1, so that it always returns an Expression regardless of the surrounding runtime context. However, this makes it hard to reuse "standard library" computations, since those would often already be wrapped as AbstractComputation.

I think the simplest solution here would be to do the following:

Some other options: