pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

CUDA_ERROR_ILLEGAL_ADDRESS when using torch profiler #6248

Open yitongh opened 10 months ago

yitongh commented 10 months ago

Using torch.profiler.profile in test/test_train_mp_imagenet.py can result in CUDA_ERROR_ILLEGAL_ADDRESS. The change I made, as shown by git diff test/test_train_mp_imagenet.py:

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 43c4c96..34f8222 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -289,23 +289,29 @@ def train_imagenet():
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
-    for step, (data, target) in enumerate(loader):
-      with xp.StepTrace('train_imagenet'):
-        with xp.Trace('build_graph'):
-          optimizer.zero_grad()
-          output = model(data)
-          loss = loss_fn(output, target)
-          loss.backward()
-          if FLAGS.ddp:
-            optimizer.step()
-          else:
-            xm.optimizer_step(optimizer)
-            tracker.add(FLAGS.batch_size)
-          if lr_scheduler:
-            lr_scheduler.step()
-        if step % FLAGS.log_steps == 0:
-          xm.add_step_closure(
-              _train_update, args=(device, step, loss, tracker, epoch, writer))
+    with torch.profiler.profile(
+          activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU
+                              ],
+          schedule=torch.profiler.schedule(wait=3, warmup=2, active=5),
+          on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile")) as prof:
+      for step, (data, target) in enumerate(loader):
+        with xp.StepTrace('train_imagenet'):
+          with xp.Trace('build_graph'):
+            optimizer.zero_grad()
+            output = model(data)
+            loss = loss_fn(output, target)
+            loss.backward()
+            if FLAGS.ddp:
+              optimizer.step()
+            else:
+              xm.optimizer_step(optimizer)
+              tracker.add(FLAGS.batch_size)
+            if lr_scheduler:
+              lr_scheduler.step()
+          if step % FLAGS.log_steps == 0:
+            xm.add_step_closure(
+                _train_update, args=(device, step, loss, tracker, epoch, writer))
+        prof.step()

   def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0

Command: PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 2 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

JackCaoG commented 10 months ago

@vanbasten23 do you have any idea?

vanbasten23 commented 10 months ago

Hi @yitongh, what's your CUDA version (the output of nvcc --version)?

yitongh commented 10 months ago

@vanbasten23 My CUDA version is 11.8 and my driver version is 470.154.

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

vanbasten23 commented 10 months ago

I tried your script in my cuda 12.1 container. I have:

root@xiowei-gpu-1:/ansible# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
root@xiowei-gpu-1:/ansible# nvidia-smi
Thu Jan  4 01:08:05 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |

but I couldn't reproduce the same error. I got a different one: https://gist.github.com/vanbasten23/fb50968127a5d17f0441753b80bcac5b

Mind sharing your stacktrace?
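
Since the two environments in this thread differ mainly in CUDA toolkit release, it can help to compare releases programmatically when collecting reports. The following is a minimal sketch, not part of PyTorch/XLA: the helper name parse_nvcc_release is hypothetical, and it only assumes that nvcc --version prints a "release X.Y" token as in the outputs pasted above.

```python
import re


def parse_nvcc_release(nvcc_output):
    """Return (major, minor) parsed from `nvcc --version` text, or None.

    Hypothetical helper for illustration; it only looks for the
    "release X.Y" token that appears in the outputs quoted in this thread.
    """
    match = re.search(r"release\s+(\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


# The two version lines quoted in this thread.
REPORTER = "Cuda compilation tools, release 11.8, V11.8.89"
MAINTAINER = "Cuda compilation tools, release 12.1, V12.1.105"

reporter_release = parse_nvcc_release(REPORTER)
maintainer_release = parse_nvcc_release(MAINTAINER)

# Tuples compare element-wise, so this orders toolkit releases correctly.
print(reporter_release, maintainer_release, reporter_release < maintainer_release)
```

On a live machine the input string could come from subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout, assuming nvcc is on the PATH.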

yitongh commented 10 months ago

My stacktrace is https://gist.github.com/yitongh/b82508236a8e1336f049abebfe7c6e0e

BTW, this error seems to occur just after the torch profiler enters its warmup phase. It could be related to the torch version: I'm using the latest PyTorch, at commit f6dfbffb3bb46ada6fe66b5da4f989f9d4d69b3c.
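
To make "just after starting the warmup phase" concrete, here is a plain-Python model of how the schedule in the diff (wait=3, warmup=2, active=5) assigns a phase to each profiler step. It is a sketch that assumes the documented behavior of torch.profiler.schedule with the default repeat=0 and skip_first=0; it does not call PyTorch, and the function name profiler_phase is hypothetical.

```python
def profiler_phase(step, wait=3, warmup=2, active=5):
    """Model the phase for a given profiler step index (0-based).

    Assumption: mirrors torch.profiler.schedule with repeat=0 and
    skip_first=0, where the wait/warmup/active cycle repeats indefinitely.
    """
    cycle = wait + warmup + active
    pos = step % cycle
    if pos < wait:
        return "NONE"             # profiler idle
    if pos < wait + warmup:
        return "WARMUP"           # tracing starts, results are discarded
    if pos < cycle - 1:
        return "RECORD"           # active recording
    return "RECORD_AND_SAVE"      # last active step triggers on_trace_ready


phases = [profiler_phase(step) for step in range(10)]
first_warmup = phases.index("WARMUP")

for step, phase in enumerate(phases):
    print(step, phase)
print("first warmup step:", first_warmup)
```

Under this model the profiler first starts tracing at step index 3, that is, in the fourth training iteration, which would be the point where the reported failure appears.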

vanbasten23 commented 10 months ago

I wonder if it's a CUDA issue. With CUDA 12.1, I don't see the error and your code runs further.