klieret opened 1 year ago
If you want to assign it to me, I can give it a shot :)
Sounds great! It seems like we could essentially do it right after the model definition, just in the training class?
For future reference, this is your commit from the old repo.
So essentially we might just modify this single line here and do `self.model = torch.jit.script(...)`.
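For a model built only from TorchScript-compatible layers, this really is a one-line change in the training class. A minimal sketch, assuming a hypothetical `TrainingWrapper` class and a toy `nn.Sequential` model standing in for the real GNN:

```python
import torch
from torch import nn


class TrainingWrapper:
    """Hypothetical stand-in for the training class that owns the model."""

    def __init__(self, model: nn.Module, use_jit: bool = True):
        # Script the model right after it has been built. This only succeeds if
        # every submodule and attribute used in forward() is TorchScript-compatible.
        self.model = torch.jit.script(model) if use_jit else model


# Toy model; the real GNN is not TorchScript-compatible as-is (see the error below)
wrapper = TrainingWrapper(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1)))
```

The `use_jit` flag (also hypothetical) keeps an easy way to fall back to eager mode while debugging.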
I tried that (very lazily) and got:
```
Module 'InteractionNetwork' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Inspector.. Its type was inferred; try adding a type annotation for the attribute.):
```
So there are probably some things to sort out.
Or instead, we could simply JIT the interaction network (IN) step itself somewhere, because that's a much simpler object than the compound models built from INs.
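One way to make the IN step itself scriptable (an assumption, not something tried in this thread) is to write it in plain PyTorch without PyG's `MessagePassing` base class, whose `inspector` attribute is what the error above trips over. A minimal sketch; the class name, dimensions, input ordering to the edge MLP, and sum aggregation are all hypothetical:

```python
import torch
from torch import nn


class PlainInteractionNetwork(nn.Module):
    """Hypothetical IN step in plain PyTorch (no PyG MessagePassing base class)."""

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int):
        super().__init__()
        # Edge (relational) MLP: [x_dst, x_src, edge_attr] -> updated edge features
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, edge_dim),
        )
        # Node (object) MLP: [x, aggregated messages] -> updated node features
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + edge_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor):
        src = edge_index[0]
        dst = edge_index[1]
        e = self.edge_mlp(torch.cat([x[dst], x[src], edge_attr], dim=1))
        # Sum the messages arriving at each target node
        agg = torch.zeros([x.size(0), e.size(1)], dtype=e.dtype, device=e.device)
        agg = agg.index_add(0, dst, e)
        x_out = self.node_mlp(torch.cat([x, agg], dim=1))
        return x_out, e


# The whole step compiles, since forward() only touches tensors and submodules
scripted_in = torch.jit.script(PlainInteractionNetwork(node_dim=3, edge_dim=4, hidden_dim=16))
```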
I JIT-compiled the MLPs within the IN for now (this works out of the box), so hopefully that already gives us some speedup.
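Scripting only the MLPs works because a stack of `nn.Linear` and activation layers is plain TorchScript-compatible code, while the PyG layer around it stays in eager mode. A minimal sketch with hypothetical layer sizes (the thread does not give the real ones):

```python
import torch
from torch import nn

# Hypothetical sizes; the real IN's MLP dimensions are not specified here
edge_mlp = nn.Sequential(nn.Linear(10, 40), nn.ReLU(), nn.Linear(40, 4))

# Script only the MLP; the surrounding MessagePassing layer is left untouched
scripted_edge_mlp = torch.jit.script(edge_mlp)
```

Inside the IN's `__init__`, the same call would wrap each MLP before it is assigned as a submodule.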
This is the error I get if I try to JIT the whole IN:
```
Module 'InteractionNetwork' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Inspector.. Its type was inferred; try adding a type annotation for the attribute.):
  File "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py", line 564
        self._explain = explain
        self.inspector.inspect(self.explain_message, pop_first=True)
        ~~~~~~~~~~~~~~ <--- HERE
        self._user_args = self.inspector.keys(methods).difference(
            self.special_args)
```
I didn't look into this further.
I think PyTorch is no longer actively developing `torch.jit.script`/`torch.jit.trace`. The new way seems to be `torch.compile`:
https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torch-compile-tutorial
And with Lightning: https://lightning.ai/blog/training-compiled-pytorch-2.0-with-pytorch-lightning/
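For reference, the basic `torch.compile` call (PyTorch 2.0 or later) from the tutorial above looks like this. A minimal sketch with a toy model standing in for the GNN; whether PyG's IN compiles cleanly has not been checked in this thread:

```python
import torch
from torch import nn

# Toy stand-in for the real model
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))

# torch.compile wraps the module and returns a callable; the actual
# compilation happens lazily on the first forward call
compiled_model = torch.compile(model)
```

Unlike `torch.jit.script`, this does not require the model's Python attributes to be convertible to TorchScript types, which is what broke scripting of the IN above.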
Javier did this for the old IN model.