RaulPPelaez closed this 1 year ago.
What happens if you compile to TorchScript and don't try to use graphs? Does that help training speed?
I was not able to train with a TorchScripted model. I suspect there might be some issue with the old-ish PyTorch Lightning version currently used. However, upgrading it is daunting (#168). When trying to train with a scripted model (this is not specific to TensorNet), I get this error: https://github.com/torchmd/torchmd-net/pull/186#issuecomment-1587513829 I have no idea where it comes from...
I ended up replacing the message passing operations. This did not result in a noticeable speedup (I am just doing what PyTorch Geometric does, but without all the boilerplate), but the code is simpler. Anyhow, I believe this round is finished and I would like to merge. Any further ideas can be introduced in another PR. Please review! cc @guillemsimeon @raimis
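As a rough illustration of what replacing the PyTorch Geometric message passing boilerplate with plain tensor operations can look like, here is a minimal sketch. The function name and the message function (source features scaled by a scalar per-edge weight) are hypothetical; the actual TensorNet messages are more involved.

```python
import torch


def message_passing_sum(x: torch.Tensor, edge_index: torch.Tensor,
                        edge_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical plain-torch replacement for a MessagePassing layer.

    x:           (num_nodes, features) node features
    edge_index:  (2, num_edges) rows are (source, target) node indices
    edge_weight: (num_edges,) scalar weight per edge (illustrative message)
    """
    src, dst = edge_index
    # "message" step: gather source-node features and scale them per edge
    messages = x[src] * edge_weight.unsqueeze(-1)
    # "aggregate" step: sum messages into their target nodes. The output
    # size is x.shape[0], so no device-to-host read is needed to size it.
    out = torch.zeros_like(x)
    out.index_add_(0, dst, messages)
    return out
```

Since the output size comes from the input shape, this form avoids the implicit size inference that can force a synchronization.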
Is it necessary to do something specific to run inference with all your optimizations, or will they work out of the box?
You are right, I forgot to comment on that. The optimizations mainly take effect when compiling the TorchMD_Net model with torch.compile:
```python
model = torch.compile(model, backend="inductor", mode="reduce-overhead")
```
I still cannot use the resulting model for training. With the current conda-forge version of PyTorch this call fails; pytorch-nightly is required. I am using this environment file:
```yaml
name: torchmdnet
channels:
  - "nvidia/label/cuda-11.8.0"
  - pytorch-nightly
  - conda-forge
dependencies:
  - python<3.11
  - h5py
  - matplotlib
  - pip
  - flake8
  - pytest
  - psutil
  - ninja
  - tqdm
  - cuda-libraries-dev<12
  - libcurand-dev<12
  - cuda-version<12
  - cuda-toolkit<12
  - cuda-nvcc<12
  - gxx_linux-64<12
  - pytorch>2.0.1
  - pip:
    - torch-sparse==0.6.17
    - torch-cluster==1.6.1
    - torch-geometric==2.3.1
    - torch-scatter==2.1.1
    - pytorch-lightning==1.6.3
    - torchmetrics==0.11.4
    - mdtraj
    - moleculekit
```
The issue with it is that NNPOps is not installable via pip, and it is incompatible with the torch nightly package, so I cannot add it here. There is also another issue in this PR: NVCC became a requirement, so it must be installed. There is the cudatoolkit-dev package in conda-forge (which takes forever to install), and also the cuda-nvcc package in the cuda channel. What do we want to do about this?
To make the model fully compatible with CUDA graphs, I need to replace scatter with a reduction:
```diff
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index d3d7e1b..29edf12 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -25,7 +25,8 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
         return

     def reduce(self, x, batch):
-        return scatter(x, batch, dim=0, reduce=self.reduce_op)
+        # return scatter(x, batch, dim=0, reduce=self.reduce_op)
+        return torch.sum(x)

     def post_reduce(self, x):
         return x
```
Of course, this assumes that there is only one molecule and that the reduction is a sum.
You can give scatter the dim_size argument with the number of molecules in the batch (the number of distinct batch indices), and the call will be CUDA-graph compatible. By default, scatter computes dim_size=batch.max()+1, which causes a sync. In the current code, just pass 1 if you only have one molecule.
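The difference between a caller-supplied and an inferred output size can be sketched as follows. To keep the example dependency-free it uses torch's built-in `index_add_` instead of torch_scatter's `scatter` (a substitution for illustration only), and both function names are hypothetical.

```python
import torch


def reduce_sum(x: torch.Tensor, batch: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Sum per-atom values into per-molecule values.

    dim_size is a plain Python int provided by the caller, so the output
    shape is known without reading `batch` back from the device.
    """
    out = x.new_zeros((dim_size,) + tuple(x.shape[1:]))
    return out.index_add_(0, batch, x)


def reduce_sum_inferred(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Mirrors scatter's default: converting batch.max() to a Python int
    # copies a value from device to host, which forces a synchronization
    # and breaks CUDA graph capture.
    dim_size = int(batch.max()) + 1
    return reduce_sum(x, batch, dim_size)
```

Both return the same values; only the first keeps the whole computation on the device.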
In general, not knowing the number of molecules is a pain for CUDA graphs. If there is only one, the batch argument can be None, which can be detected and exploited without syncs. Otherwise, unless we add a "num_molecules" argument, I see no way to avoid a sync.
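A minimal sketch of the three cases described above, assuming a hypothetical `num_molecules` argument (it is not an existing TorchMD-Net parameter) and again using `index_add_` in place of torch_scatter:

```python
from typing import Optional

import torch


def reduce_energy(x: torch.Tensor, batch: Optional[torch.Tensor] = None,
                  num_molecules: Optional[int] = None) -> torch.Tensor:
    """Hypothetical reduction of per-atom outputs into per-molecule outputs.

    - batch is None: every atom belongs to a single molecule, plain sum.
    - num_molecules given: output size is known up front, no sync.
    - otherwise: fall back to batch.max() + 1, which synchronizes.
    """
    if batch is None:
        # keepdim=True keeps the (1, features) layout of the batched case
        return x.sum(dim=0, keepdim=True)
    if num_molecules is None:
        num_molecules = int(batch.max()) + 1  # device-to-host copy (sync)
    out = x.new_zeros((num_molecules,) + tuple(x.shape[1:]))
    return out.index_add_(0, batch, x)
```

Only the last branch reads device data on the host, so the first two can be captured in a CUDA graph.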
OK! This is enough to fix graphing, but we need a general solution:
```diff
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index d3d7e1b..45e66eb 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -25,7 +25,7 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
         return

     def reduce(self, x, batch):
-        return scatter(x, batch, dim=0, reduce=self.reduce_op)
+        return scatter(x, batch, dim=0, dim_size=1, reduce=self.reduce_op)

     def post_reduce(self, x):
         return x
```
Great! Did you find out anything about the ungraphable operation in backpropagation?
With https://github.com/torchmd/torchmd-net/pull/195#discussion_r1325742792 and https://github.com/torchmd/torchmd-net/pull/195#issuecomment-1719195035, I can graph the forward and backward passes, but let me check the full pipeline.
Awesome! masked_scatter has been giving me a hard time from the start, so I can believe it is the offending function. Thanks for the replacement!
CUDA graphs work, except for OutputModel, which will be addressed in #214 or another PR.
This PR is an effort to optimize TensorNet. Note that I also merged all the changes from #186 here.
I have focused on inference for small molecules (such as Alanine, 22 atoms), where we saw that the CUDA kernel launch overhead was tremendous. My strategy was simply to modify what was needed for `torch.compile(backend="inductor", mode="reduce-overhead")` to put as much as possible into a CUDA graph. I replaced some blocking operations and was able to convince TorchDynamo to put the bulk of the forward pass into a CUDA graph (orange is a graph). Zooming in, I am having trouble making TorchDynamo understand the neighbor kernels. Alas, TorchDynamo refuses to graph the backward pass. But I know it is possible to graph it, since I can do so manually.
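The manual graphing is not shown in this thread. As a hedged sketch of one way to do it, `torch.cuda.make_graphed_callables` captures both the forward and the backward pass of a module, provided input shapes stay fixed. The toy energy model and helper names are hypothetical stand-ins, not TensorNet; on a CPU-only machine the code falls back to eager execution.

```python
import torch
from torch import nn


def build_energy_model() -> nn.Module:
    # Hypothetical stand-in for a per-atom energy network (not TensorNet).
    return nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))


def energy_and_forces(model: nn.Module, pos: torch.Tensor):
    """Total energy and forces (negative gradient w.r.t. positions)."""
    energy = model(pos).sum()
    (grad,) = torch.autograd.grad(energy, pos)
    return energy.detach(), -grad


def maybe_graph(model: nn.Module, sample_pos: torch.Tensor) -> nn.Module:
    """Capture forward and backward of `model` in CUDA graphs if on a GPU.

    Every later call must use inputs with the same shape, dtype and
    requires_grad as `sample_pos`: graphs need static shapes.
    """
    if not sample_pos.is_cuda:
        return model  # CUDA graphs need a GPU; run eagerly on CPU
    return torch.cuda.make_graphed_callables(model, (sample_pos,))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_energy_model().to(device)
# 22 atoms, matching the alanine example size
pos = torch.randn(22, 3, device=device, requires_grad=True)
graphed = maybe_graph(model, pos)
energy, forces = energy_and_forces(graphed, pos)
```

The static-shape requirement is why this works for fixed-size inference but not directly for training batches.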
Anyhow, I believe this problem will solve itself in time as TorchDynamo improves its support for things like dynamic shapes and extensions. And even today, the speedup is impressive.
I am not sure how to apply optimizations such as CUDA graphs to training, where shapes are dynamic. In that setting, the changes in this PR will simply not improve things as much. TorchDynamo refuses to even process the model when used during training, complaining about dynamic shapes.
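One common workaround, not something this PR does, is to pad each training batch to one of a few fixed atom counts so that shapes repeat. The sketch below is an assumption-laden illustration. It assumes a fixed number of molecules per batch. It also assumes the neighbor search only pairs atoms that share a batch index, so padding atoms never interact with real ones. All names are hypothetical.

```python
import torch


def pad_to_bucket(pos: torch.Tensor, batch: torch.Tensor, num_molecules: int,
                  bucket: int = 64):
    """Pad atoms up to a multiple of `bucket` (hypothetical helper).

    Padding atoms go to an extra dummy molecule with index num_molecules,
    so shapes only take a few distinct values across batches.
    Assumes neighbor search never pairs atoms from different batch indices.
    """
    n = pos.shape[0]
    padded_n = ((n + bucket - 1) // bucket) * bucket
    pad = padded_n - n
    pos_p = torch.cat([pos, pos.new_zeros((pad, 3))])
    batch_p = torch.cat([batch, batch.new_full((pad,), num_molecules)])
    return pos_p, batch_p


def per_molecule_sum(values: torch.Tensor, batch_p: torch.Tensor,
                     num_molecules: int) -> torch.Tensor:
    # Fixed-size reduction over num_molecules + 1 slots, then drop the
    # dummy slot that collected the padding atoms.
    out = values.new_zeros((num_molecules + 1,) + tuple(values.shape[1:]))
    out.index_add_(0, batch_p, values)
    return out[:-1]
```

The bucket size trades wasted work on padding atoms against the number of distinct shapes that need to be compiled or captured.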
You will notice the benchmark tries to compile the model by doing this:
This line fails miserably with the current environment; I had to use pytorch-nightly (which provides version 2.1.0):
That meant installing the rest of the torch friends with pip:
Luckily, these compile just fine with pip. With a little luck, once https://github.com/conda-forge/pytorch-cpu-feedstock/pull/172 is merged, we can update the environment and stay on conda-forge.