gnn-tracking / gnn_tracking

Reconstruct billions of particle trajectories with graph neural networks
https://gnn-tracking.rtfd.io/
MIT License
30 stars, 13 forks

torch.jit.script PyG models #346

Status: Open. klieret opened this issue 1 year ago

klieret commented 1 year ago

Javier already did this for the old IN (interaction network) model.

jmduarte commented 1 year ago

If you want to assign this to me, I can give it a shot :)

klieret commented 1 year ago

Sounds great! It seems like we can essentially do it right after the model definition, just in the training class?

For future reference, this is your commit from the old repo.

klieret commented 1 year ago

So essentially we might just modify this single line here and do

self.model = torch.jit.script( ...)
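A minimal sketch of that idea, assuming a plain `torch.nn.Module` built only from scriptable layers (no PyG-specific attributes). The classes `TinyEdgeClassifier` and `TrainingSetup` are hypothetical stand-ins for the project's model and training class, not code from the repo:

```python
import torch
from torch import nn


class TinyEdgeClassifier(nn.Module):
    """Hypothetical stand-in for a model made only of scriptable layers."""

    def __init__(self, in_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One logit per edge, squeezed from shape (n_edges, 1) to (n_edges,)
        return self.net(x).squeeze(-1)


class TrainingSetup:
    """Hypothetical training class: the model is scripted once, right after it is built."""

    def __init__(self, model: nn.Module) -> None:
        # torch.jit.script compiles the module and all of its submodules recursively
        self.model = torch.jit.script(model)


setup = TrainingSetup(TinyEdgeClassifier(in_dim=8, hidden_dim=16))
out = setup.model(torch.randn(5, 8))
```

This only works if every attribute used in `forward` has a type TorchScript understands, which is exactly what fails for the PyG-based models in the comments below.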

klieret commented 1 year ago

I tried that very lazily and got

Module 'InteractionNetwork' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Inspector.. Its type was inferred; try adding a type annotation for the attribute.):

so there might be some things to be sorted out.


Or instead, we could simply JIT the interaction network step itself somewhere, because that's a much simpler object than the compound models built from INs.

klieret commented 11 months ago

I JIT-compiled the MLPs within the IN for now (this works out of the box), so hopefully that already gives us some speedup.
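A sketch of what "scripting only the MLPs" can look like, assuming a simplified interaction step written as a plain `nn.Module` with sum aggregation. `InteractionStep`, `make_mlp`, `relational_mlp`, and `object_mlp` are hypothetical names, not the project's actual IN code (which subclasses PyG's `MessagePassing`); the point is that the outer module stays in eager mode while its MLP submodules are TorchScript modules:

```python
from typing import Tuple

import torch
from torch import nn


def make_mlp(in_dim: int, out_dim: int, hidden_dim: int = 32) -> nn.Module:
    """Build a small MLP and script it; plain nn.Sequential stacks script without changes."""
    mlp = nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
    )
    return torch.jit.script(mlp)


class InteractionStep(nn.Module):
    """Hypothetical simplified IN step: eager outer module, scripted inner MLPs."""

    def __init__(self, node_dim: int, edge_dim: int) -> None:
        super().__init__()
        # Edge update sees both endpoint features plus the current edge features
        self.relational_mlp = make_mlp(2 * node_dim + edge_dim, edge_dim)
        # Node update sees the node features plus the summed incoming edge messages
        self.object_mlp = make_mlp(node_dim + edge_dim, node_dim)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        src, dst = edge_index
        # Edge (relational) update, run by a scripted MLP
        edge_out = self.relational_mlp(torch.cat([x[src], x[dst], edge_attr], dim=-1))
        # Sum-aggregate edge messages onto their target nodes (eager Python)
        agg = torch.zeros(x.size(0), edge_out.size(-1), dtype=x.dtype, device=x.device)
        agg.index_add_(0, dst, edge_out)
        # Node (object) update, run by a scripted MLP
        x_out = self.object_mlp(torch.cat([x, agg], dim=-1))
        return x_out, edge_out
```

Any speedup is limited to the MLP calls; the gather/scatter logic in `forward` still runs as ordinary Python.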

klieret commented 11 months ago

This is the message I get if I try to JIT the whole IN:

Module 'InteractionNetwork' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Inspector.. Its type was inferred; try adding a type annotation for the attribute.):
  File "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py", line 564

        self._explain = explain
        self.inspector.inspect(self.explain_message, pop_first=True)
        ~~~~~~~~~~~~~~ <--- HERE
        self._user_args = self.inspector.keys(methods).difference(
            self.special_args)

I didn't look into this any further.
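Since some modules script cleanly (the MLPs) while others fail (the full IN, because of attributes like PyG's `Inspector`), one option is to attempt scripting and fall back to eager mode when it fails. A minimal sketch; `maybe_script` is a hypothetical helper, not part of PyTorch or PyG:

```python
import warnings

import torch
from torch import nn


def maybe_script(module: nn.Module) -> nn.Module:
    """Hypothetical helper: return a scripted module if TorchScript accepts it,
    otherwise warn and return the original eager module unchanged."""
    try:
        return torch.jit.script(module)
    except Exception as e:  # TorchScript can raise several different error types here
        warnings.warn(f"Could not script {type(module).__name__}, keeping eager mode: {e}")
        return module
```

Catching a broad `Exception` is deliberate here: scripting failures surface as different exception classes depending on where compilation breaks (unsupported syntax, attribute type inference, missing source), so the fallback should not depend on one of them.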

klieret commented 9 months ago

I think PyTorch is no longer actively developing TorchScript (torch.jit.script/torch.jit.trace). The new way seems to be torch.compile:

https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torch-compile-tutorial
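A minimal `torch.compile` sketch (PyTorch 2.0 or later), using a hypothetical `TinyMLP` stand-in and a hypothetical `compile_model` wrapper. Unlike `torch.jit.script`, `torch.compile` does not reject the whole module when it hits code it cannot capture; it inserts a "graph break" and runs that part as regular Python instead:

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    """Hypothetical stand-in model."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def compile_model(model: nn.Module, backend: str = "inductor") -> nn.Module:
    # torch.compile wraps the module; the actual compilation happens lazily
    # on the first forward call, not when this function returns.
    return torch.compile(model, backend=backend)
```

`"inductor"` is the default backend and needs a working compiler toolchain; passing `backend="eager"` skips code generation, which is handy for checking that graph capture works before worrying about speed.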

klieret commented 9 months ago

And with Lightning: https://lightning.ai/blog/training-compiled-pytorch-2.0-with-pytorch-lightning/