nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

How to load a model.pth saved with torch.jit.trace(model, input)? I didn't find any relevant examples #578

Open muwys518 opened 3 months ago

muwys518 commented 3 months ago

I save the torch model with

model = mobilenet_v2(weights='DEFAULT')
traced_model = torch.jit.trace(model, input)
torch.jit.save(traced_model, 'model.pth')

then I load model.pth and try to export it with

model = torch.load('model.pth')
export_output = aot.export(model, input)

I get the following error:

  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 204, in export
    cm = Exported(context=context, import_to="import")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 578, in __new__
    do_export(proc_def)
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 575, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 121, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 556, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 188, in main
    return jittable(mdl.forward)(*args)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/roger.luo/miniconda3/envs/test-env-pth2mlir/lib/python3.11/site-packages/shark_turbine/aot/builtins/jittable.py", line 175, in __init__
    self.function_name = function_name if function_name else wrapped_f.__name__
                                                             ^^^^^^^^^^^^^^^^^^
AttributeError: 'torch._C.ScriptMethod' object has no attribute '__name__'
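The traceback above can be reproduced without SHARK-Turbine: torch.jit.trace returns a ScriptModule whose forward is a torch._C.ScriptMethod, which (unlike a plain nn.Module's bound method) does not expose __name__, so jittable fails when it reads wrapped_f.__name__. A minimal torch-only demonstration (Tiny is a placeholder module, not from the issue):

```python
import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2

model = Tiny()
x = torch.ones(1)
traced = torch.jit.trace(model, x)

# A regular bound method carries __name__ ...
print(hasattr(model.forward, "__name__"))   # True
# ... but the traced module's forward is a ScriptMethod, which is what
# jittable trips over in the traceback above.
print(type(traced.forward))
```

This is why passing the result of torch.load on a traced model into aot.export fails: the exporter expects a module whose forward is an ordinary Python method.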
IanNod commented 3 months ago

aot.export traces the model itself to generate the IR we use to compile it, much like torch.jit.trace produced your model.pth. So you would want to pass the original (untraced) model like this:

model = mobilenet_v2(weights='DEFAULT')
export_output = aot.export(model, input)

The .pth file can then be used as external weights at runtime. An example of compiling a model with external weights enabled can be found here: https://github.com/nod-ai/SHARK-Turbine/blob/55e8703abb4f5ad73b9b95ef5e9e0db20a84b7b4/models/turbine_models/custom_models/stateless_llama.py#L180

and of using them at runtime here: https://github.com/nod-ai/SHARK-Turbine/blob/55e8703abb4f5ad73b9b95ef5e9e0db20a84b7b4/models/turbine_models/model_runner.py#L20
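The weights-half of that split can be sketched with torch alone: save only the state_dict (no traced graph) to a file, then rebuild the module and restore the weights before exporting. This is a minimal illustrative sketch; Tiny is a placeholder module, and the SHARK-Turbine external-weights wiring from the linked files is not reproduced here.

```python
import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = Tiny()
# Save weights only -- this file holds tensors, not a ScriptModule graph,
# so it can serve as an external-weights file loaded at runtime.
torch.save(model.state_dict(), "weights.pth")

# Rebuild the plain nn.Module and restore the weights; this untraced
# module is what you would hand to aot.export.
model2 = Tiny()
model2.load_state_dict(torch.load("weights.pth"))
model2.eval()

x = torch.ones(1, 4)
assert torch.equal(model(x), model2(x))
```

The key point is that aot.export wants the live nn.Module, while the .pth file is reduced to a plain tensor container rather than a serialized traced program.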