We see ~24 GB of memory used for a model of size 6.4 GB. Moreover, compiling with IREE through the command line shows memory usage more in line with the size of the model. The likely cause is that we're passing modules as strings between SharkInference, SharkRunner, etc.:
    def __init__(
        self,
        mlir_module: str,
        function_name: str = "forward",
        device: str = "none",
        mlir_dialect: str = "linalg",
        is_benchmark: bool = False,
    ):
        # This exists for the lifetime of the SharkInference object
        self.mlir_module = mlir_module
        self.function_name = function_name
        self.device = shark_args.device if device == "none" else device
        self.mlir_dialect = mlir_dialect
        self.is_benchmark = is_benchmark
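Note that Python does not copy a string when it is passed as an argument: the callee receives a reference to the same object, so passing alone should not duplicate the module. However, each layer keeps the module alive for its own lifetime, and any conversion along the way (for example between str and bytes) allocates a full new copy. By the time the module has gone from download_torch_model to SharkInference to SharkRunner and we have accrued a compiled vmfb, that could add up to storing ~4 copies of the module.
One solution that might work (haven't tried myself yet) is to wrap the module in a list (not a tuple), so that all layers share one mutable container and the module can be removed from it once it is no longer needed. Alternatively we could try explicit garbage collection, although any copies made along the way would still exist temporarily.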
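Whether passing the string around really duplicates it can be checked directly. Below is a minimal sketch, not SHARK code: `Inference` and `Runner` are hypothetical stand-ins for SharkInference and SharkRunner, and a 10 MB dummy string stands in for the module. It uses `tracemalloc` to compare the memory cost of handing the string through two layers with the cost of converting it to bytes.

```python
import tracemalloc


class Runner:
    """Hypothetical stand-in for SharkRunner: stores what it is handed."""

    def __init__(self, module: str):
        self.module = module


class Inference:
    """Hypothetical stand-in for SharkInference: stores and forwards the module."""

    def __init__(self, module: str):
        self.module = module
        self.runner = Runner(module)


tracemalloc.start()

# A 10 MB ASCII string standing in for the MLIR module text.
module = "x" * (10 * 1024 * 1024)
before, _ = tracemalloc.get_traced_memory()

# Hand the string through both layers.
inference = Inference(module)
after_passing, _ = tracemalloc.get_traced_memory()

# All three names refer to one and the same string object.
same_object = inference.module is module and inference.runner.module is module

# A conversion, by contrast, allocates a second full-size buffer.
encoded = module.encode("utf-8")
after_encoding, _ = tracemalloc.get_traced_memory()

passing_cost = after_passing - before
encoding_cost = after_encoding - after_passing
print(f"same object: {same_object}")
print(f"bytes allocated by passing:  {passing_cost}")
print(f"bytes allocated by encoding: {encoding_cost}")

tracemalloc.stop()
```

If the real pipeline behaves like this sketch, the places worth auditing are the ones that convert or re-serialize the module, rather than the call boundaries themselves.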
These numbers are from stable_diffusion, which is the easiest place to see the excess memory usage.