Exscientia / physicsml

A package for all physics based/related models
MIT License

Cannot compile to torchscript #23

Closed: peastman closed this issue 6 months ago

peastman commented 6 months ago

I've trained a TensorNet model and saved it to disk. I now want to load it and compile it to TorchScript.

import torch
from molflux.core import load_model
model = load_model('logs/saved_model')
module = torch.jit.script(model)

It fails with this error.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_script.py", line 1351, in script
    return torch.jit._recursive.create_script_class(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 448, in create_script_class
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 51, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 
Unknown type constructor Type:
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/physicsml/models/tensor_net/supervised/tensor_net_model.py", line 22
    @property
    def _config_builder(self) -> Type[TensorNetModelConfig]:
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return TensorNetModelConfig

I also tried compiling just the contained module:

module = torch.jit.script(model.module)

That fails with a different exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 477, in create_script_module
    concrete_type = get_module_concrete_type(nn_module, share_types)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 428, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 369, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 277, in infer_concrete_type_builder
    overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module, ignored_properties)))
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 643, in get_overload_annotations
    item = getattr(mod, name, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 212, in trainer
    raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
RuntimeError: PooledTensorNetModule is not attached to a `Trainer`.

To work around that problem, I tried creating a new Trainer to attach:

import pytorch_lightning as pl
model.module._trainer = pl.Trainer()

That leads to yet another exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 477, in create_script_module
    concrete_type = get_module_concrete_type(nn_module, share_types)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 428, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 369, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 333, in infer_concrete_type_builder
    attr_type, inferred = infer_type(name, value)
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_recursive.py", line 178, in infer_type
    ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/annotations.py", line 419, in ann_to_type
    the_type = try_ann_to_type(ann, loc)
               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/annotations.py", line 410, in try_ann_to_type
    return torch.jit._script._recursive_compile_class(ann, loc)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/jit/_script.py", line 1465, in _recursive_compile_class
    rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/_jit_internal.py", line 454, in createResolutionCallbackForClassMethods
    captures.update(get_type_hint_captures(fn))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/site-packages/torch/_jit_internal.py", line 370, in get_type_hint_captures
    src = inspect.getsource(fn)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/inspect.py", line 1262, in getsource
    lines, lnum = getsourcelines(object)
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/inspect.py", line 1244, in getsourcelines
    lines, lnum = findsource(object)
                  ^^^^^^^^^^^^^^^^^^
  File "/Users/peastman/miniconda3/envs/physicsml/lib/python3.11/inspect.py", line 1081, in findsource
    raise OSError('could not get source code')
OSError: could not get source code

This is with Python 3.11.8, PyTorch 2.0.1, and PhysicsML 0.2.0.

wardhaddadin1 commented 6 months ago

Hello Peter!

Yeah, it's not as simple as just calling torch.jit.script, because there are objects attached to the model and module that are not TorchScript-compatible (like losses, etc.).

If you want to get an OpenMM-compatible TorchScript model, you can use the following (this is what the openmm-ml interface uses internally):

from physicsml.plugins.openmm.load import to_openmm_torchscript

# Each argument is shown with its default; at minimum, set model_path and atom_list (or system_path)
ts_model = to_openmm_torchscript(
    model_path=None,  # Optional[str]: path to the saved model
    atom_list=None,  # Optional[List[int]]: atomic numbers of the system
    system_path=None,  # Optional[str]: or a path to a system file (e.g. a PDB file)
    atom_idxs=None,  # Optional[List[int]]: indices of the atoms to include in the TorchScript model (for mixed systems)
    y_output="y_graph_scalars",  # Optional[str]: the model output to use
    pbc=None,  # Optional[Tuple[bool, bool, bool]]: periodic boundary conditions
    cell=None,  # Optional[List[List[float]]]: cell dimensions
    output_scaling=None,  # Optional[float]: factor to scale the output by (e.g. for different units)
    position_scaling=None,  # Optional[float]: factor to scale the input positions by (to match the training-data units)
    device="cpu",  # "cpu" or "cuda"
    precision="32",  # "32" or "64"
    torchscipt_path=None,  # Optional[str]: if set, also write the model out to a TorchScript file
)

Let me know if this works for you!

Best, Ward

peastman commented 6 months ago

I saw that function, but I couldn't figure out how to make it work for my model. For one thing, it expects saved featurization metadata, and my model has none because I generated the fully featurized dataset directly rather than using a featurizer. For another, I have extra node features that need to be passed to the model (specified with physicsml_atom_features in the dataset), and I couldn't figure out how to pass those in.

wardhaddadin1 commented 6 months ago

Yeah, for this you would need to persist the featurisation metadata (otherwise there is no way to know how to reproduce the featurisation). If you do, loading with this function should take care of everything (even extra node features).

If you want to work around it, you can do what that function does internally to make the module TorchScript-compatible:

from molflux.core import load_model

model = load_model("path")

del model.module.losses  # losses are not torchscriptable
del model.module.model_config  # config is not torchscriptable
ts_model = model.module.to_torchscript() 

The ts_model is now just a TorchScript version of the original module, and it expects the full batch dict as input (more info here).
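For readers less familiar with TorchScript, here is a minimal, self-contained sketch of what a scripted module that takes a batch dict looks like, and how such a module can be saved and reloaded with the standard torch.jit.save / torch.jit.load calls. ToyBatchModule and the "coordinates" key are hypothetical stand-ins chosen for illustration; the real module and batch keys come from PhysicsML, and only the "y_graph_scalars" output name appears in this thread.

```python
import io
from typing import Dict

import torch
from torch import nn


class ToyBatchModule(nn.Module):
    """Hypothetical stand-in for a PhysicsML module: maps a batch dict to an output dict."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # "coordinates" is a hypothetical key; PhysicsML defines the real batch layout
        per_atom = self.linear(batch["coordinates"])  # shape (num_atoms, 1)
        return {"y_graph_scalars": per_atom.sum(dim=0)}  # shape (1,)


scripted = torch.jit.script(ToyBatchModule())

# Round-trip through an in-memory buffer; a file path such as "model.pt" works the same way
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

batch = {"coordinates": torch.randn(5, 3)}
out = reloaded(batch)
print(out["y_graph_scalars"])
```

Since to_torchscript() also returns a TorchScript module, the same torch.jit.save call should work on ts_model, giving a file that can be loaded later without the original Python classes.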

Hope that helps! Let me know if you need anything else.

Best, Ward

peastman commented 6 months ago

del model.module.losses  # losses are not torchscriptable
del model.module.model_config  # config is not torchscriptable
ts_model = model.module.to_torchscript()

That sounds perfect. I'll try it out and see if it works.

you would need to persist the featurisation metadata (otherwise there is no way to know how to reproduce the featurisation).

The concept of featurization isn't really relevant to what I'm doing. I'm training a model that takes certain inputs and produces certain outputs. How the user chooses to generate the inputs is outside the scope of the model, and it can't be done by PhysicsML.

peastman commented 6 months ago

That works. Thanks!