Instead of returning a Hidet `CompiledGraph` to the client, the server should return a `torch.fx.GraphModule`.
Here, it returns a `GraphModule` that simply wraps the `CompiledGraph`'s compiled forward function.
(Note: the `GraphModule` does NOT trace through the forward function; it treats it as a black box. Tracing would probably try to optimize the forward function, but Hidet's forward function has dynamic control flow that isn't easily traceable.)
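A minimal sketch of the wrapping idea, assuming PyTorch's `torch.fx` and using a toy function in place of Hidet's compiled forward. The graph is assembled by hand with `fx.Graph`, so FX never traces into the wrapped function. The toy function branches on tensor values, which is the kind of dynamic control flow `fx.symbolic_trace` rejects with a `TraceError`. `LeafModule`, `wrap_as_graph_module`, and `data_dependent_forward` are hypothetical names for illustration, not Hidet APIs.

```python
import torch
import torch.fx as fx


class LeafModule(torch.nn.Module):
    """Hypothetical module that hides a compiled forward function from FX."""

    def __init__(self, forward_fn):
        super().__init__()
        self.forward_fn = forward_fn

    def forward(self, args):
        # Inputs arrive packed in a single tuple; unpack them for the wrapped function.
        return self.forward_fn(*args)


def wrap_as_graph_module(forward_fn):
    """Build placeholder -> call_module -> output by hand, without tracing forward_fn."""
    root = torch.nn.Module()
    root.leaf_module = LeafModule(forward_fn)
    graph = fx.Graph()
    x = graph.placeholder("x")
    out = graph.call_module("leaf_module", (x,))
    graph.output(out)
    # GraphModule copies root.leaf_module over because a call_module node references it.
    return fx.GraphModule(root, graph)


def data_dependent_forward(a):
    # Toy stand-in for a compiled forward: branching on tensor values is dynamic
    # control flow that fx.symbolic_trace cannot follow.
    if a.sum() > 0:
        return a * 2
    return a - 1
```

Calling `wrap_as_graph_module(data_dependent_forward)((torch.ones(2),))` runs the branch eagerly at call time, since the wrapped function is opaque to the graph.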
The `GraphModule`'s graph looks like:

```
opcode       name         target       args            kwargs
-----------  -----------  -----------  --------------  --------
placeholder  x            x            ()              {}
call_module  leaf_module  leaf_module  (x,)            {}
output       output       output       (leaf_module,)  {}
```
It expects all arguments of the wrapped forward function to be passed together as a single tuple.
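The same three-node graph can also be produced by tracing a thin wrapper with a custom `fx.Tracer` that marks the leaf module as a leaf, so tracing stops at the module boundary. This is a sketch assuming `torch.fx`; `LeafModule`, `Wrapper`, `LeafTracer`, and `trace_wrapper` are hypothetical names, and it is not necessarily how this PR builds its graph. It also shows the tuple calling convention: the single placeholder `x` receives the whole argument tuple.

```python
import torch
import torch.fx as fx


class LeafModule(torch.nn.Module):
    """Hypothetical stand-in for a module wrapping a compiled forward function."""

    def __init__(self, forward_fn):
        super().__init__()
        self.forward_fn = forward_fn

    def forward(self, args):
        # All inputs arrive packed in one tuple and are unpacked here.
        return self.forward_fn(*args)


class Wrapper(torch.nn.Module):
    def __init__(self, forward_fn):
        super().__init__()
        self.leaf_module = LeafModule(forward_fn)

    def forward(self, x):
        return self.leaf_module(x)


class LeafTracer(fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Record LeafModule as a single call_module node instead of tracing into it.
        if isinstance(m, LeafModule):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def trace_wrapper(forward_fn):
    root = Wrapper(forward_fn)
    graph = LeafTracer().trace(root)
    return fx.GraphModule(root, graph)


gm = trace_wrapper(lambda a, b: a + b)
for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args)
# Both arguments go in one tuple, matching the single placeholder `x`.
result = gm((torch.ones(3), torch.arange(3.0)))
```

A hand-built `fx.Graph` and a leaf-aware tracer end up in the same place; the tracer route is convenient when the wrapper does more than forward a single call.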