The CentML compilation backend I am working on needs to wrap `CompiledGraph`'s forward function (the one returned by `get_wrapper`) in a `torch.fx.GraphModule`. This `GraphModule` would then be pickled and sent from a server to a client.
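However, the lambda/local function returned by `get_wrapper` cannot be pickled: `pickle` serializes functions by reference to their importable qualified name, and a function defined inside another function has none. Therefore, I am turning `get_wrapper` into a class, `CompiledForwardFunction`, whose `forward` method behaves like the `wrapper` returned by `get_wrapper`.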
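A minimal, hidet-independent sketch of why the class-based approach helps. Here `double` is a hypothetical stand-in for a `CompiledGraph`, and the body of `CompiledForwardFunction` is an assumption for illustration, not the exact implementation in this PR. The closure fails to pickle, while an instance of a module-level class round-trips.

```python
import pickle


def get_wrapper(graph):
    # Closure-based wrapper: `wrapper` captures `graph` from the enclosing
    # scope, so pickle has no importable name to reference it by.
    def wrapper(*args):
        return graph(*args)

    return wrapper


class CompiledForwardFunction:
    # Sketch only: the captured state becomes an instance attribute, and the
    # class is importable at module level, so its instances can be pickled.
    def __init__(self, graph):
        self.graph = graph

    def forward(self, *args):
        # Same behaviour as the wrapper returned by get_wrapper.
        return self.graph(*args)

    def __call__(self, *args):
        return self.forward(*args)


def double(x):
    # Hypothetical stand-in for a CompiledGraph (any picklable callable).
    return 2 * x


def is_picklable(obj):
    # The exception type for unpicklable local functions varies by Python version.
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


print(is_picklable(get_wrapper(double)))              # False
print(is_picklable(CompiledForwardFunction(double)))  # True
restored = pickle.loads(pickle.dumps(CompiledForwardFunction(double)))
print(restored(21))  # 42
```

Pickling the instance only works if everything it holds (here `self.graph`) is itself picklable, which is why `CompiledGraph` also needs pickling support.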
Additionally, in order to pickle `CompiledForwardFunction`, I have defined pickling and unpickling behaviour for `CompiledGraph` using `__getstate__` and `__setstate__`, respectively. These simply call `CompiledGraph`'s existing `save` and `load` functions.
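A sketch of the `__getstate__`/`__setstate__` pattern described above, assuming `save` writes to a file path and `load` rebuilds an object from one. `ToyCompiledGraph` and its JSON-based `save`/`load` are hypothetical stand-ins, not hidet's actual `CompiledGraph` API. The pickled state is simply the bytes that `save` produced.

```python
import json
import os
import pickle
import tempfile


class ToyCompiledGraph:
    """Hypothetical stand-in for CompiledGraph with file-based save/load."""

    def __init__(self, weights):
        self.weights = weights

    def save(self, path):
        # Stand-in for the existing serialization routine.
        with open(path, "w") as f:
            json.dump(self.weights, f)

    @staticmethod
    def load(path):
        # Stand-in for the existing deserialization routine.
        with open(path) as f:
            return ToyCompiledGraph(json.load(f))

    def __getstate__(self):
        # Serialize through save() into a temp file, then hand pickle the bytes.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "graph")
            self.save(path)
            with open(path, "rb") as f:
                return {"blob": f.read()}

    def __setstate__(self, state):
        # Write the bytes back to disk, rebuild through load(), adopt its state.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "graph")
            with open(path, "wb") as f:
                f.write(state["blob"])
            loaded = ToyCompiledGraph.load(path)
        self.__dict__.update(loaded.__dict__)

    def __call__(self, x):
        return sum(w * x for w in self.weights)


graph = ToyCompiledGraph([1.0, 2.0])
restored = pickle.loads(pickle.dumps(graph))
print(restored(3.0))  # 9.0
```

Delegating to the existing `save`/`load` keeps a single serialization format to maintain, at the cost of a round trip through a temporary file on each pickle and unpickle.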