google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Implement a pass that partitions an FHE circuit into parallelizable levels #230

Closed: j2kun closed this issue 11 months ago

j2kun commented 11 months ago

The gates in each level should be parallelizable. One could achieve this by porting https://github.com/google/fully-homomorphic-encryption/blob/1998f23a8fa97dabac9d1008c57e4a18ed3bee3d/transpiler/graph.h#L117, or else by using MLIR's existing topological-sort utilities.
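A minimal sketch of the level partition itself, assuming the circuit is flattened into a list of gates that each name the gates they consume. The `Gate`, `assign_levels`, and `partition_by_level` names are hypothetical; this is not the transpiler's `graph.h` code or a HEIR API. It uses Kahn's topological sort and places each gate one level above its deepest input, so no two gates in the same level depend on each other:

```rust
use std::collections::VecDeque;

/// A gate in a hypothetical flattened circuit: `inputs` holds the indices of
/// the gates whose outputs it consumes. Primary inputs have no inputs.
pub struct Gate {
    pub inputs: Vec<usize>,
}

/// Assigns each gate the length of the longest path from a primary input,
/// using Kahn's topological sort. Returns None if the graph has a cycle.
pub fn assign_levels(gates: &[Gate]) -> Option<Vec<usize>> {
    let n = gates.len();
    // users[g] lists the gates that consume gate g's output.
    let mut users = vec![Vec::new(); n];
    // pending[g] counts inputs of g that have not been processed yet.
    let mut pending = vec![0usize; n];
    for (g, gate) in gates.iter().enumerate() {
        for &i in &gate.inputs {
            users[i].push(g);
            pending[g] += 1;
        }
    }
    let mut level = vec![0usize; n];
    let mut ready: VecDeque<usize> = (0..n).filter(|&g| pending[g] == 0).collect();
    let mut visited = 0;
    while let Some(g) = ready.pop_front() {
        visited += 1;
        for &u in &users[g] {
            // A gate sits one level above its deepest input.
            level[u] = level[u].max(level[g] + 1);
            pending[u] -= 1;
            if pending[u] == 0 {
                ready.push_back(u);
            }
        }
    }
    // Any unvisited gate lies on a cycle.
    (visited == n).then_some(level)
}

/// Groups gate indices by level: partition[k] holds every gate at level k.
pub fn partition_by_level(levels: &[usize]) -> Vec<Vec<usize>> {
    let depth = levels.iter().max().map_or(0, |&m| m + 1);
    let mut partition = vec![Vec::new(); depth];
    for (g, &l) in levels.iter().enumerate() {
        partition[l].push(g);
    }
    partition
}
```

Primary inputs land at level 0, and a cyclic graph is reported as `None` rather than partially partitioned.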

j2kun commented 11 months ago

I imagine the right approach here would be to use an existing interface, or define a new interface, and then express the resulting partition in terms of scf.forall, which has parallel semantics.

j2kun commented 11 months ago

OK, thinking about this more, I have a few ideas to brainstorm (@asraa, since you are working on related topics). The main constraint here comes from the last time we did this codegen. When we put the entire FHE circuit inside a single function, rustc could no longer compile it, so the data defining the gates and wires had to be made into global static data. We organized that data into logical levels and used rayon to iterate in parallel over the gates in each level.
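To illustrate that layout, here is a minimal sketch in which the gate data lives in a per-level static table rather than in one giant function. Plaintext `bool`s stand in for ciphertexts, the `Op`, `Gate`, `FULL_ADDER`, and `run` names are hypothetical, and `std::thread::scope` is swapped in for rayon so the example needs only the standard library:

```rust
use std::thread;

/// Hypothetical two-input gate kinds. Plaintext bools stand in for
/// ciphertexts, so `eval` is a placeholder for a real FHE gate call.
#[derive(Clone, Copy)]
pub enum Op {
    And,
    Or,
    Xor,
}

/// One gate: its op, the two wires it reads, and the wire it writes.
pub struct Gate {
    pub op: Op,
    pub lhs: usize,
    pub rhs: usize,
    pub out: usize,
}

/// A full adder stored as static data, one slice per level.
/// Wires 0..3 are the inputs a, b, c; wire 5 is the sum, wire 7 the carry.
pub static FULL_ADDER: &[&[Gate]] = &[
    // Level 0: reads only primary inputs.
    &[
        Gate { op: Op::Xor, lhs: 0, rhs: 1, out: 3 },
        Gate { op: Op::And, lhs: 0, rhs: 1, out: 4 },
    ],
    // Level 1: reads wire 3 from level 0.
    &[
        Gate { op: Op::Xor, lhs: 3, rhs: 2, out: 5 },
        Gate { op: Op::And, lhs: 3, rhs: 2, out: 6 },
    ],
    // Level 2: reads wires 4 and 6.
    &[Gate { op: Op::Or, lhs: 4, rhs: 6, out: 7 }],
];

fn eval(op: Op, a: bool, b: bool) -> bool {
    match op {
        Op::And => a & b,
        Op::Or => a | b,
        Op::Xor => a ^ b,
    }
}

/// Evaluates the circuit one level at a time. Gates in a level are
/// independent, so each runs on its own scoped thread; the scope joins
/// all of them before the next level starts.
pub fn run(levels: &[&[Gate]], wires: &mut [bool]) {
    for level in levels {
        // Shared, read-only view of the wires for this level.
        let current: &[bool] = wires;
        let results: Vec<(usize, bool)> = thread::scope(|s| {
            let handles: Vec<_> = level
                .iter()
                .map(|g| s.spawn(move || (g.out, eval(g.op, current[g.lhs], current[g.rhs]))))
                .collect();
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        });
        // Write outputs only after every gate in the level has finished.
        for (out, value) in results {
            wires[out] = value;
        }
    }
}
```

The implicit join at the end of each `thread::scope` is what guarantees a level's outputs exist before the next level reads them. Spawning one OS thread per gate is only for illustration; rayon's `par_iter` would reuse a thread pool instead.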

Should we even care about the constraint?

We could ignore the constraint entirely and chalk it up as a limitation of emitting Rust. In all likelihood, the longer-term plan will be to lower to LLVM and call the tfhe-rust C API. So that is one option.

How should I record the partition?

We could write an analysis pass that assigns each gate the integer level it belongs to in the partitioned circuit. We could also implement it as a transform that inserts scf.forall loop ops and puts the gate inputs into tensors or memrefs, which would then be code-genned directly as rayon::par_iter. In both cases, the contents of each level would need to be extracted to global data via a full IR walk, because the values produced by a level are not available for the next level to consume until that level completes.

What dialect should I target for this transform?

I could target comb, cggi, or tfhe_rust. Comb is a natural choice, but then the lowerings from comb to cggi would need to be aware of whatever mechanism I use to "record" the resulting partition, and future transforms in lower-level dialects would need to work around it. E.g., if I write scf.forall ops into the IR, the lowerings and cggi passes would have to know about forall, which could add a lot of complexity.

I could also implement some sort of "gate-like" interface to make this easier. However, in tfhe_rust, the "shift + add" trick that combines two inputs into a single gate makes the link between a gate and its "inputs" a length-3 use-def chain, so moving a gate op would also require moving the rest of that chain. This suggests that applying the transform higher in the dialect hierarchy would be better, since in comb there are no extra steps between gates.
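To see why the chain has length 3, here is a plaintext model of the trick, assuming 1-bit inputs and a two-input truth table packed into the low 4 bits of a byte. The functions are hypothetical stand-ins for the separate shift, add, and lookup-table ops, not the tfhe-rs API:

```rust
/// Hypothetical plaintext model of the "shift + add" trick: each function
/// stands in for one op in the chain that feeds a single gate.
pub fn shift(a: u8) -> u8 {
    a << 1
}

pub fn add(x: u8, y: u8) -> u8 {
    x + y
}

/// Looks up bit `index` of a 4-entry truth table packed into `table`.
pub fn apply_lut(table: u8, index: u8) -> u8 {
    (table >> index) & 1
}

/// A two-input gate as a length-3 chain: shift, then add, then lookup.
/// With 1-bit inputs, the packed index is 2a + b, in the range 0..=3.
pub fn gate(table: u8, a: u8, b: u8) -> u8 {
    apply_lut(table, add(shift(a), b))
}

// Bit i of each table is the gate's output for packed index i.
pub const AND: u8 = 0b1000;
pub const OR: u8 = 0b1110;
pub const XOR: u8 = 0b0110;
```

Because one logical gate is three separate ops in the IR, relocating the lookup-table op to another level only works if the shift and add that feed it move with it.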

Should we implement a more generic scheduler instead?

An alternative to level partitioning with par_iter is a more general, and likely more performant, scheduler that fires off a new thread to execute whichever gate is earliest in topological order and has all its inputs ready.
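A minimal sketch of such a dataflow scheduler, under a few assumptions: gates are indexed in topological order, the graph is acyclic (a cycle would leave the workers waiting forever), and `exec` is a hypothetical callback standing in for evaluating one gate. Instead of firing off a new thread per gate, it uses a fixed pool of `workers` threads, each of which takes the ready gate with the smallest index:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::{Condvar, Mutex};
use std::thread;

/// A gate that consumes the outputs of the gates listed in `inputs`.
/// Assumes gate indices already follow a topological order.
pub struct Gate {
    pub inputs: Vec<usize>,
}

struct State {
    ready: BinaryHeap<Reverse<usize>>, // min-heap: earliest ready gate first
    pending: Vec<usize>,               // unfinished inputs per gate
    remaining: usize,                  // gates not yet finished
}

/// Runs `exec` once per gate on `workers` threads and returns the order in
/// which gates finished. Assumes the gate graph is acyclic.
pub fn schedule<F>(gates: &[Gate], workers: usize, exec: F) -> Vec<usize>
where
    F: Fn(usize) + Sync,
{
    let n = gates.len();
    let mut users = vec![Vec::new(); n];
    let mut pending = vec![0; n];
    for (g, gate) in gates.iter().enumerate() {
        for &i in &gate.inputs {
            users[i].push(g);
            pending[g] += 1;
        }
    }
    let ready = (0..n).filter(|&g| pending[g] == 0).map(Reverse).collect();
    let state = Mutex::new(State { ready, pending, remaining: n });
    let wake = Condvar::new();
    let finished = Mutex::new(Vec::with_capacity(n));

    thread::scope(|s| {
        for _ in 0..workers {
            s.spawn(|| loop {
                // Wait until some gate is ready or every gate is done.
                let g = {
                    let mut st = state.lock().unwrap();
                    loop {
                        if st.remaining == 0 {
                            return;
                        }
                        if let Some(Reverse(g)) = st.ready.pop() {
                            break g;
                        }
                        st = wake.wait(st).unwrap();
                    }
                };
                exec(g); // runs outside the lock, so gates execute concurrently
                finished.lock().unwrap().push(g);
                // Release dependents whose inputs are now all available.
                let mut st = state.lock().unwrap();
                st.remaining -= 1;
                for &u in &users[g] {
                    st.pending[u] -= 1;
                    if st.pending[u] == 0 {
                        st.ready.push(Reverse(u));
                    }
                }
                wake.notify_all();
            });
        }
    });
    finished.into_inner().unwrap()
}
```

Unlike the level-by-level approach, a gate can start as soon as its own inputs finish, even while other gates at the same depth are still running.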

Having written all this, I'm thinking it may be easiest to skip the problem entirely, not do any parallelism in the tfhe-rs emitter, and spend the time saved on scheduling at a lower level.

j2kun commented 11 months ago

Tentatively agreed to NOT do this for now. Closing.

github-actions[bot] commented 9 months ago

Note: Re-commenting because this issue was closed with unresolved TODOs.

This issue has 1 outstanding TODO:

This comment was autogenerated by todo-backlinks

j2kun commented 9 months ago

I think this TODO can be deleted, so not reopening.