hidet-org / hidet

An open-source efficient deep learning framework/compiler, written in python.
https://hidet.org
Apache License 2.0

[Dynamo] Refactor get_wrapper and pickling compiled graph #438

Closed: destefy closed this 6 months ago

destefy commented 6 months ago

The CentML compilation backend I am working on wants to wrap the CompiledGraph's forward function (the one returned by get_wrapper) in a torch.fx.GraphModule. This GraphModule would then be pickled and sent from a server to a client.

However, the lambda/local function returned by get_wrapper cannot be pickled. Therefore, I am turning get_wrapper into a class, CompiledForwardFunction, whose forward function behaves like the wrapper that get_wrapper used to return.
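To illustrate the pickling problem, here is a minimal sketch in plain Python. It does not use hidet's code: this `get_wrapper` and `CompiledForwardFunction` are simplified, hypothetical stand-ins that only multiply by a number. The point is that pickle stores functions and classes by importable name, so a function defined inside another function fails, while an instance of a module-level class works.

```python
import pickle


def get_wrapper(scale):
    # Hypothetical stand-in: returns a local closure, which pickle cannot
    # locate by name because it lives in get_wrapper.<locals>.
    def wrapper(x):
        return x * scale

    return wrapper


class CompiledForwardFunction:
    # Hypothetical stand-in: a module-level class. Instances are pickled as
    # "class name + instance state", so they survive a round trip.
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

    def __call__(self, x):
        return self.forward(x)


try:
    pickle.dumps(get_wrapper(3))
    closure_pickles = True
except (pickle.PicklingError, AttributeError):
    # The exact exception type differs between Python versions.
    closure_pickles = False

restored = pickle.loads(pickle.dumps(CompiledForwardFunction(3)))
print(closure_pickles, restored(5))
```

The callable object keeps the same call signature as the closure, so code that previously invoked the wrapper should not need to change.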

Additionally, in order to pickle CompiledForwardFunction, I have defined pickling and unpickling behaviour for CompiledGraph using __getstate__ and __setstate__ respectively. These just call CompiledGraph's existing save and load functions.
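As a rough sketch of this delegation pattern, the toy class below routes `__getstate__` and `__setstate__` through file-based `save` and `load` methods. `ToyCompiledGraph`, its text file format, and the use of a temporary directory to carry the saved bytes are all assumptions made for illustration; the actual CompiledGraph implementation in this PR may differ.

```python
import os
import pickle
import tempfile


class ToyCompiledGraph:
    """Hypothetical stand-in for a class that already has save/load."""

    def __init__(self, weights):
        self.weights = list(weights)

    def save(self, path):
        # Assumed format: comma-separated floats in a text file.
        with open(path, "w") as f:
            f.write(",".join(str(w) for w in self.weights))

    @staticmethod
    def load(path):
        with open(path) as f:
            text = f.read()
        return ToyCompiledGraph(float(w) for w in text.split(",") if w)

    def __getstate__(self):
        # Reuse save(): write to a temporary file, then return its bytes
        # as the pickled state.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "graph.txt")
            self.save(path)
            with open(path, "rb") as f:
                return {"blob": f.read()}

    def __setstate__(self, state):
        # Reuse load(): write the bytes back to disk, load them, and copy
        # the loaded object's attributes onto this instance.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "graph.txt")
            with open(path, "wb") as f:
                f.write(state["blob"])
            loaded = ToyCompiledGraph.load(path)
        self.__dict__.update(loaded.__dict__)


graph = ToyCompiledGraph([1.0, 2.5])
clone = pickle.loads(pickle.dumps(graph))
print(clone.weights)
```

Delegating to save and load keeps a single serialization format, so objects that hold a graph can be pickled without defining a second one.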