Welp, I was not counting on this one...
Python builtin <built-in function _cuda_isCurrentStreamCapturing> is currently not supported in Torchscript:
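For context, this is the kind of call that triggers the error (a minimal sketch, not the exact call site in the PR):

```python
import torch

# Scripting any function that calls torch.cuda.is_current_stream_capturing()
# fails, because the underlying _cuda_isCurrentStreamCapturing builtin has no
# TorchScript binding.
@torch.jit.script
def fails_to_compile() -> bool:
    return torch.cuda.is_current_stream_capturing()
```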
I have an idea of how to circumvent this.
This extension does the trick, but I just managed to get rid of nvcc as a dependency and this would bring it back -.-
```python
import torch
from torch.utils.cpp_extension import load_inline
import torch.nn as nn

cpp_source = '''
#include <torch/script.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

// Returns true if the current CUDA stream is capturing a graph.
bool is_stream_capturing() {
    at::cuda::CUDAStream current_stream = at::cuda::getCurrentCUDAStream();
    cudaStream_t cuda_stream = current_stream.stream();
    cudaStreamCaptureStatus capture_status;
    cudaError_t err = cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr);
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
    return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
}

// Register the function as a TorchScript custom op so scripted code can call
// torch.ops.torch_extension.is_stream_capturing().
static auto registry =
    torch::RegisterOperators()
        .op("torch_extension::is_stream_capturing", &is_stream_capturing);
'''

# Create an inline extension (plain C++ sources, so no nvcc is invoked)
torch_extension = load_inline(
    "is_stream_capturing",
    cpp_sources=cpp_source,
    functions=["is_stream_capturing"],
    with_cuda=True,
    verbose=True,
)


# TorchScript can call the registered custom op even though it cannot call
# torch.cuda.is_current_stream_capturing() directly.
@torch.jit.script
def check_stream_capturing():
    return torch.ops.torch_extension.is_stream_capturing()


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        is_capturing = (
            x.is_cuda
            and check_stream_capturing()
        )
        # Synchronizing is illegal during graph capture, so only do it otherwise.
        if not is_capturing:
            torch.cuda.synchronize()
        y = x * 2
        return y


x = torch.ones(1).to("cuda").requires_grad_(True)
model = MyModel()
model = torch.jit.script(model)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    # Warmup
    y = model(x)
model = torch.cuda.make_graphed_callables(model, (x,), allow_unused_input=True)
y = model(x)
```
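A quick way to see the op flip during capture (an untested sketch using `torch.cuda.graph`, not part of the PR; assumes a warmed-up CUDA context):

```python
# Outside capture the op reports False; inside a capture region it reports True.
assert not check_stream_capturing()

g = torch.cuda.CUDAGraph()
capture_input = torch.ones(1, device="cuda")
with torch.cuda.graph(g):
    assert check_stream_capturing()
    capture_output = capture_input * 2
```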
Also, I do not know how to ship the extension itself along with the model.
As a matter of fact, this does not use nvcc; I think everything required is already in cudatoolkit!
@raimis please review again so I can merge
The reduction in OutputModel requires knowing the number of batches (computed as `batch.max() + 1`) in order to produce a tensor with one entry per batch (i.e. the total energy). Turning that GPU scalar into a host-side integer is inherently blocking, so it cannot run during CUDA graph capture.
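Schematically, the reduction looks something like this (illustrative names, not the actual OutputModel code); the `int()` conversion is the blocking device-to-host copy:

```python
import torch

def per_sample_energy(atom_energies: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # batch[i] gives the sample each atom belongs to. Sizing the output needs
    # the number of batches; converting the GPU scalar batch.max() + 1 to a
    # Python int copies it to the host and synchronizes, which is forbidden
    # while a CUDA graph is being captured.
    num_batches = int(batch.max()) + 1
    energies = torch.zeros(num_batches, dtype=atom_energies.dtype, device=atom_energies.device)
    return energies.index_add_(0, batch, atom_energies)
```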
One way to work around this limitation is to make OutputModel store the number of batches each time it is called. Then, if the module is called in capture mode (which can be detected with `torch.cuda.is_current_stream_capturing()`), the previously stored value is used instead. As a consequence, the number of batches of a captured model cannot be larger than the one used in the call just before capture. However, this limitation is inherent to CUDA graphs, so I do not consider it an additional one.

This PR implements this strategy, warning the user about it when they capture the model and raising an error if the model is called without warming up first. Possible alternatives: