torchmd / torchmd-net

Training neural network potentials
MIT License

Make OutputModel aware of CUDA graph capturing. #214

Closed. RaulPPelaez closed this 1 year ago

RaulPPelaez commented 1 year ago

The reduction in OutputModel requires knowing the number of batches (computed as batch.max()+1) in order to produce a tensor with one value per batch (e.g. the total energy). This operation is inherently blocking and thus cannot be called during CUDA graph capture.
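
To make the blocking nature concrete, here is a minimal sketch (hypothetical names, not the actual OutputModel code) of this kind of reduction: the batch count has to reach the host before the per-batch output can be allocated, which is exactly the synchronization that graph capture forbids.

import torch

def reduce_per_batch(per_atom_energy: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # batch.max() lives on the GPU; turning it into a Python int forces a
    # device-to-host copy, i.e. a synchronization point that is illegal
    # while a CUDA graph is being captured.
    num_batches = int(batch.max()) + 1
    out = torch.zeros(num_batches, dtype=per_atom_energy.dtype, device=per_atom_energy.device)
    # Sum each per-atom contribution into the slot of its batch (e.g. molecule).
    return out.index_add_(0, batch, per_atom_energy)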

One way to work around this limitation is to make the OutputModel store the number of batches each time it is called. Then, if the module is called in capture mode (which can be detected with torch.cuda.is_current_stream_capturing()), the previously stored value is used instead. This means that the number of batches seen by a captured model cannot be larger than the one used in the call preceding the capture. However, that restriction is inherent to CUDA graphs, so I do not consider it an additional limitation.
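
Roughly, the idea looks like this (a simplified sketch of the strategy, not the exact code in this PR; the class and attribute names are made up):

import torch
from torch import nn

class CachedBatchReduce(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_batches = -1  # filled in by the first eager (warmup) call

    def forward(self, per_atom_energy: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        if torch.cuda.is_current_stream_capturing():
            # Inside capture we must not synchronize, so reuse the cached value
            # and refuse to run if the model was never warmed up.
            if self.num_batches < 0:
                raise RuntimeError("Call the model at least once before capturing it")
        else:
            # Eager call: synchronizing is fine, refresh the cache.
            self.num_batches = int(batch.max()) + 1
        out = torch.zeros(self.num_batches, dtype=per_atom_energy.dtype,
                          device=per_atom_energy.device)
        return out.index_add_(0, batch, per_atom_energy)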

This PR implements this strategy, warning the user when they capture the model and raising an error if the model is called in capture mode without being warmed up first. Possible alternatives:

  1. Adding a new parameter to TorchMD_Net (like n_batches). I think this is more error-prone.
  2. Computing and storing batch.max()+1 during capture without including it in the graph. Sadly, I do not think this is possible.
RaulPPelaez commented 1 year ago

Welp, I was not counting on this one...

Python builtin <built-in function _cuda_isCurrentStreamCapturing> is currently not supported in Torchscript:
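
For reference, this is roughly the kind of code that triggers that error (a minimal repro sketch; the exact behavior depends on the PyTorch version):

import torch

# torch.cuda.is_current_stream_capturing() dispatches to the C builtin
# torch._C._cuda_isCurrentStreamCapturing, which TorchScript could not compile
# at the time of this PR, so scripting this function raises the error quoted above.
@torch.jit.script
def check_capturing() -> bool:
    return torch.cuda.is_current_stream_capturing()
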
raimis commented 1 year ago

I have an idea how to circumvent this.

RaulPPelaez commented 1 year ago

This extension does the trick, but I just managed to get rid of nvcc as a dependency and this would bring it back -.-

import torch
from torch.utils.cpp_extension import load_inline
import torch.nn as nn

cpp_source = '''
#include <torch/script.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

// Returns true if the current CUDA stream is being captured into a graph.
bool is_stream_capturing() {
  at::cuda::CUDAStream current_stream = at::cuda::getCurrentCUDAStream();
  cudaStream_t cuda_stream = current_stream.stream();
  cudaStreamCaptureStatus capture_status;
  cudaError_t err = cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr);

  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }

  return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
}

// Register the function as a custom op so it is callable from TorchScript.
static auto registry =
  torch::RegisterOperators()
    .op("torch_extension::is_stream_capturing", &is_stream_capturing);
'''

# Compile and load the extension in-process; the sources are plain C++, so this
# should only need the host compiler and the CUDA runtime headers, not nvcc.
torch_extension = load_inline(
    "is_stream_capturing",
    cpp_sources=cpp_source,
    functions=["is_stream_capturing"],
    with_cuda=True,
    verbose=True,
)

# TorchScript-compatible wrapper: the registered custom op is visible to the JIT
# through the torch.ops namespace.
@torch.jit.script
def check_stream_capturing():
    return torch.ops.torch_extension.is_stream_capturing()

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        is_capturing = (
            x.is_cuda
            and check_stream_capturing()
        )
        # torch.cuda.synchronize() is a blocking call, so only issue it when the
        # current stream is not being captured into a graph.
        if not is_capturing:
            torch.cuda.synchronize()
        y = x * 2
        return y

x = torch.ones(1).to("cuda").requires_grad_(True)
model = MyModel()
model = torch.jit.script(model)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    # Warm up the model once before capturing it
    y = model(x)
    model = torch.cuda.make_graphed_callables(model, (x,), allow_unused_input=True)
    y = model(x)

Also, I do not know how to ship the extension itself along with the model.

RaulPPelaez commented 1 year ago

As a matter of fact, this does not use nvcc; I think everything required is already in cudatoolkit!

RaulPPelaez commented 1 year ago

@raimis please review again so I can merge