NVIDIA / cuda-checkpoint

CUDA checkpoint and restore utility

cuDevicePrimaryCtxGetState() returns error 3 (CUDA_ERROR_NOT_INITIALIZED) in a resumed snapshot under certain circumstances #15

Open paulpopelka opened 2 days ago

paulpopelka commented 2 days ago

We are using cuda-checkpoint to save the GPU context into the process's memory, and then using our own snapshot facility to create an executable that resumes execution at the point of the snapshot. The current version of our snapshot tool is not open source.

The application we are using is PyTorch. Briefly, this is the Python program we run:

============================================================================================

from os import getenv
import os.path
import time

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import torch
from transformers import pipeline, set_seed

device = "cuda:0" # if torch.cuda.is_available() else "cpu"
generator = pipeline('text-generation', model='gpt2', pad_token_id=50256, device=device)

set_seed(42)

# wait for them to run cuda-checkpoint --toggle --pid xxxx
print("waiting for dosnap to exist")
while not os.path.exists('dosnap'):
    time.sleep(.1)

if getenv("MAKE_SNAPSHOT") is not None:
    from ctypes import CDLL
    kontain = CDLL("libkontain.so")
    print(kontain.snapshot("pytorch", "sentiment", 0))

# let them run cuda-checkpoint --toggle --pid xxxx
print("waiting for dorun to exist")
while not os.path.exists('dorun'):
    time.sleep(.1)

content = "late in the afternoon"
output = generator(content, max_length=30, num_return_sequences=1)
print(output)

============================================================================================

Put MAKE_SNAPSHOT=1 in the environment and remove the file "dosnap" before running the Python program. Run the program using our snapshot generator/resume code. Once you see the "waiting for dosnap to exist" message, run "cuda-checkpoint --toggle --pid xxxx", then run "touch dosnap"; a snapshot will be generated.

Remove the file "dorun" before resuming the snapshot. After resuming one of our snapshots, wait for the "waiting for dorun to exist" prompt, run "cuda-checkpoint --toggle --pid xxxx", and then run "touch dorun"; the resumed program continues past the wait and runs the inference.
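The manual steps above can be sketched as a small driver. This is a hypothetical helper, not part of cuda-checkpoint or of our snapshot tool; it only assumes that `cuda-checkpoint` is on `PATH` and that the target PID is already known. By default it does a dry run and returns the commands instead of executing them.

```python
import subprocess
from pathlib import Path

# Hypothetical helper for the manual repro steps; not part of cuda-checkpoint.
TRIGGERS = ("dosnap", "dorun")


def reset_triggers(names=TRIGGERS):
    """Remove stale trigger files before starting the program or resuming a snapshot."""
    for name in names:
        Path(name).unlink(missing_ok=True)


def phase_commands(pid, trigger):
    """One phase: toggle the process's CUDA state, then create the trigger file."""
    return [
        ["cuda-checkpoint", "--toggle", "--pid", str(pid)],
        ["touch", trigger],
    ]


def run_phase(pid, trigger, dry_run=True):
    """Run the commands for one phase, or only return them when dry_run is True."""
    cmds = phase_commands(pid, trigger)
    if not dry_run:
        for cmd in cmds:
            subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
    return cmds
```

In this sketch, `run_phase(pid, "dosnap", dry_run=False)` would be used once the program prints "waiting for dosnap to exist", and `run_phase(pid, "dorun", dry_run=False)` once the resumed snapshot prints "waiting for dorun to exist".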

The resumed program then fails with the following error:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA driver error: initialization error
Exception raised from _hasPrimaryContext at ../aten/src/ATen/cuda/detail/CUDAHooks.cpp:67 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fffec981d87 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fffec93275f in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0xfb3196 (0x7fff2d1b3196 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10::cuda::MaybeSetDevice(int) + 0xc (0x7fffed29fc4c in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x3042b75 (0x7fff2f242b75 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x30bf5bd (0x7fff2f2bf5bd in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xe4c29d (0x7fff2d04c29d in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x30fc6bd (0x7fff2f2fc6bd in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #8: at::_ops::addmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) + 0x86 (0x7fff60e07bd6 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x42dd013 (0x7fff62cdd013 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x42de023 (0x7fff62cde023 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::_ops::addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) + 0x19e (0x7fff60e7b3fe in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x54c6b7 (0x7fffbd14c6b7 in /home/paulp/ai0/km-gpu/tests/km-demo/textgen-gpt2/env/lib64/python3.12/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>

In the PyTorch source, aten/src/ATen/cuda/detail/CUDAHooks.cpp contains the following:

bool _hasPrimaryContext(DeviceIndex device_index) {
  TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
              "hasPrimaryContext expects a valid device index, but got device_index=", device_index);
  unsigned int ctx_flags;
  // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird
  // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero.
  int ctx_is_active = 0;
  AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
  return ctx_is_active == 1;
}

Line 67 of that file, reported in the exception above, is the call to cuDevicePrimaryCtxGetState().

With a breakpoint placed just after cuDevicePrimaryCtxGetState() returns, we see that it returns 3, which is CUDA_ERROR_NOT_INITIALIZED. The documentation for that error says the driver has not been initialized with cuInit() (or that initialization failed).
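To exercise the same driver call outside PyTorch, a minimal sketch using Python's `ctypes` can call `cuInit()`, `cuDeviceGet()` and `cuDevicePrimaryCtxGetState()` directly. It assumes the driver library is loadable as `libcuda.so.1` (the usual soname on Linux), maps only a few `CUresult` values to names, and returns `None` when the driver cannot be loaded. The function names are hypothetical. Like `_hasPrimaryContext()`, it initializes the "active" output to zero.

```python
import ctypes

# A few CUresult codes from cuda.h; any other code is shown numerically.
CU_RESULT_NAMES = {
    0: "CUDA_SUCCESS",
    3: "CUDA_ERROR_NOT_INITIALIZED",
    4: "CUDA_ERROR_DEINITIALIZED",
    100: "CUDA_ERROR_NO_DEVICE",
    101: "CUDA_ERROR_INVALID_DEVICE",
}


def result_name(code):
    return CU_RESULT_NAMES.get(code, f"CUresult({code})")


def load_driver(soname="libcuda.so.1"):
    """Load the CUDA driver library, or return None if it is not available (assumed soname)."""
    try:
        return ctypes.CDLL(soname)
    except OSError:
        return None


def probe_primary_ctx(ordinal=0, call_init=True, lib=None):
    """Return (failing_call_or_None, result_name, flags, active), or None without a driver."""
    if lib is None:
        lib = load_driver()
    if lib is None:
        return None
    if call_init:
        rc = lib.cuInit(0)
        if rc != 0:
            return ("cuInit", result_name(rc), None, None)
    dev = ctypes.c_int(0)
    rc = lib.cuDeviceGet(ctypes.byref(dev), ordinal)
    if rc != 0:
        return ("cuDeviceGet", result_name(rc), None, None)
    flags = ctypes.c_uint(0)
    active = ctypes.c_int(0)  # zero-initialized, as in PyTorch's _hasPrimaryContext
    rc = lib.cuDevicePrimaryCtxGetState(dev, ctypes.byref(flags), ctypes.byref(active))
    call = None if rc == 0 else "cuDevicePrimaryCtxGetState"
    return (call, result_name(rc), flags.value, active.value)
```

In a fresh process, `probe_primary_ctx(call_init=False)` would be expected to report CUDA_ERROR_NOT_INITIALIZED from the first driver call, which is the documented meaning of error 3. The failure described here is different: it occurs even though cuInit() has been called.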

I placed a breakpoint on cuInit() and reran the test. I found that cuInit() had been called before the failing call to cuDevicePrimaryCtxGetState().

Wondering whether there had been other calls to cuDevicePrimaryCtxGetState(), I placed a breakpoint on it and resumed the snapshot again. There had been approximately 50 successful calls; then cuInit() was called, followed by the failing call to cuDevicePrimaryCtxGetState().

If we perform one inference before running cuda-checkpoint, then take a snapshot, resume it, run cuda-checkpoint, and perform another inference, the inference after the resume works: cuDevicePrimaryCtxGetState() returns no error.
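The working variant differs from the failing one in where the first GPU inference happens. In the failing run, the first inference (the `addmm` path in the trace, which reaches `MaybeSetDevice()` and `_hasPrimaryContext()`) happens only after the resume; in the working run, it has already happened before the checkpoint. Whether that ordering is the actual cause is not established here. A sketch of the change, using a hypothetical `warm_up()` helper and the `generator` pipeline from the script:

```python
def warm_up(generator, prompt="late in the afternoon", max_length=30):
    """Run one inference so the first GPU work happens before the checkpoint.

    Hypothetical helper: call it right after the pipeline is created and
    set_seed(42), before waiting for 'dosnap' and running
    cuda-checkpoint --toggle.
    """
    return generator(prompt, max_length=max_length, num_return_sequences=1)

# Placement in the script (sketch):
#   generator = pipeline('text-generation', model='gpt2', pad_token_id=50256, device=device)
#   set_seed(42)
#   warm_up(generator)   # first inference, before the checkpoint/snapshot
#   ... wait for 'dosnap', snapshot, wait for 'dorun', run the second inference ...
```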

If you remove MAKE_SNAPSHOT from the environment and run the test again (so no snapshot is generated or resumed), the program works and the call to cuDevicePrimaryCtxGetState() does not fail.

Can you help me understand what is wrong?

jesus-ramos commented 2 days ago

Could you describe how your snapshot tool generates the snapshot and restores it after the cuda-checkpoint toggle?

If I'm understanding your flow correctly, the failure happens after the app is toggled to the checkpointed state, snapshotted, resumed from the snapshot, and toggled back to the running state, after which a subsequent call to cuDevicePrimaryCtxGetState() returns CUDA_ERROR_NOT_INITIALIZED. Is the snapshot resumed from a cold start or with the currently running app (similar to CRIU --leave-running)?

One thing to note: NVML support isn't available yet, so the checkpoint will leave some stale references to /dev/nvidiactl and /dev/nvidia0...N. This is unlikely to be the issue with your snapshot tool, but just in case, there may be some entries in /proc/pid/maps that aren't handled properly. Usually PyTorch apps only use NVML at startup to query information and don't touch it again.
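To look for the stale device references mentioned above, a process's /proc/&lt;pid&gt;/maps and /proc/&lt;pid&gt;/fd can be scanned for /dev/nvidia* paths. A minimal sketch, assuming the standard Linux /proc layout (address, perms, offset, dev, inode, optional pathname per maps line); the function names are hypothetical:

```python
import os


def nvidia_device_mappings(maps_text):
    """Return (address_range, perms, path) for maps lines backed by /dev/nvidia* files."""
    hits = []
    for line in maps_text.splitlines():
        fields = line.split(maxsplit=5)
        if len(fields) < 6:
            continue  # anonymous mapping: no pathname column
        path = fields[5].strip()
        if path.startswith("/dev/nvidia"):
            hits.append((fields[0], fields[1], path))
    return hits


def nvidia_mappings_for_pid(pid):
    with open(f"/proc/{pid}/maps") as f:
        return nvidia_device_mappings(f.read())


def nvidia_fds_for_pid(pid):
    """Map open file descriptor numbers to /dev/nvidia* targets."""
    fd_dir = f"/proc/{pid}/fd"
    found = {}
    for fd in os.listdir(fd_dir):
        try:
            target = os.readlink(os.path.join(fd_dir, fd))
        except OSError:
            continue  # descriptor closed while iterating
        if target.startswith("/dev/nvidia"):
            found[int(fd)] = target
    return found
```

Running both functions on the resumed process and comparing the results with the original process would show whether any /dev/nvidia* mappings or descriptors survived the snapshot.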

I'll try to test this internally, replacing the snapshot tool with CRIU, and see if I can replicate it.

jesus-ramos commented 2 days ago

I tested this with internal NVML support and CRIU with the CUDA plugin, and it looks like it was able to checkpoint and restore properly.

I removed the getenv() check and the second wait loop (for dorun), and let the application run until the first "waiting for dosnap" message. I then ran criu dump on the process with the CUDA plugin, which dumped the process, and the process exited. I then ran criu restore from the dump, ran touch dosnap, and let it resume and exit.
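For comparison, the CRIU flow above (dump the process while it waits for dosnap, then restore it from the images) can be sketched as two command lines. This is an assumption-laden sketch with hypothetical function names: `--tree`, `--images-dir`, `--shell-job` and `--leave-running` are standard CRIU options, but how the CUDA plugin is located (plugin directory, build options) depends on the installation and is not shown.

```python
def criu_dump_cmd(pid, images_dir, leave_running=False, shell_job=True):
    """Build (but do not run) a 'criu dump' command line for pid."""
    cmd = ["criu", "dump", "--tree", str(pid), "--images-dir", images_dir]
    if shell_job:
        cmd.append("--shell-job")  # assumed: the process is attached to a terminal
    if leave_running:
        cmd.append("--leave-running")  # keep the original process running after the dump
    return cmd


def criu_restore_cmd(images_dir, shell_job=True):
    """Build (but do not run) a 'criu restore' command line for images_dir."""
    cmd = ["criu", "restore", "--images-dir", images_dir]
    if shell_job:
        cmd.append("--shell-job")
    return cmd
```

With `leave_running=False` this corresponds to the dump-then-exit flow described above; `leave_running=True` corresponds to the --leave-running case asked about earlier.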

Here's the output, just in case; I got the same results with and without the CRIU dump.

Truncation was not explicitly activated but max_length is provided a specific value, please use truncation=True to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to truncation. [{'generated_text': 'late in the afternoon for the season opener against Houston.\n\nWash. Tatum suffered his MCL sprain in the 6-2 loss'}]

By the way, which driver version are you running?