Open · ErnestChan opened this issue 1 year ago
I've narrowed down the issue a bit. It happens when find_unused_parameters=True for DDP. If I remove all unused parameters from the model and run it with find_unused_parameters=False, there are no cudagraph_trees compile errors. However, if I run the job with find_unused_parameters=True, there are assertion errors from this line in beginAllocateStreamToPool, even though there are no unused parameters in the model. I still haven't been able to find a small reproducible example.

We often do multi-task training where not all tasks run on each iteration, so it's hard to avoid setting find_unused_parameters=True.
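For reference, a minimal sketch of the kind of setup we run. The module names and sizes are made up for illustration, and it assumes the process group is already initialized:

```python
# Hypothetical sketch, not our actual model: a multi-task net where only one
# head runs per iteration, so the other head's parameters go unused.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(128, 128)
        self.head_a = nn.Linear(128, 10)
        self.head_b = nn.Linear(128, 10)

    def forward(self, x, task):
        h = torch.relu(self.backbone(x))
        # Only one head fires per step; the other head's parameters are
        # unused in that iteration's autograd graph.
        return self.head_a(h) if task == "a" else self.head_b(h)

model = DDP(
    MultiTaskModel().cuda(),
    device_ids=[torch.cuda.current_device()],
    find_unused_parameters=True,  # needed because a head is skipped each step
)
# mode="reduce-overhead" enables cudagraph trees in inductor.
compiled = torch.compile(model, mode="reduce-overhead")
```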
@ernestchan For the errant data pointer, you can set cudagraph_trees_history_recording to true to get a better error message.
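A sketch of where I believe that flag lives, assuming it is the inductor config option of the same name (check against your build):

```python
# Assumed location of the flag in recent builds; verify against your version.
import torch._inductor.config as inductor_config

# Record allocation history so errant-data-pointer errors can report where
# the offending tensor was allocated.
inductor_config.triton.cudagraph_trees_history_recording = True
```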
Just curious - what happens when you set torch._dynamo.config.optimize_ddp to True?
cc @yf225 @wconstab for any thoughts.
I'll try to find an OSS model that I can repro this with.
> I'll try to find an OSS model that I can repro this with.

Much appreciated 👍
I tried setting cudagraph_trees_history_recording, but it didn't seem to print more info. I can try it again.
> Just curious - what happens when you set torch._dynamo.config.optimize_ddp to True?

We turned it off because it causes other errors, namely with parameterized modules, and even if those are avoided we get stride assert errors.
I've narrowed down the issue a bit further. When I comment out the code that adds _DDPSink in nn/parallel/distributed, distributed training works. With _DDPSink, it looks like there are two concurrent threads running the code in cudagraph_trees.py when backward starts: the main thread and a dummy thread, which I assume is started by the C++ autograd engine. This causes the assertion errors, since two threads end up inside the _use_cuda_memory_pool_manager context manager at the same time; a toy illustration of the race is below.
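To make the failure mode concrete, here is a toy sketch (plain Python, not the actual PyTorch internals) of why two threads concurrently inside a context manager that flips process-wide allocator state would trip an assert like the one in beginAllocateStreamToPool:

```python
# Toy illustration only: a context manager guarding process-wide state,
# loosely analogous to _use_cuda_memory_pool_manager routing allocations
# to a private pool.
import threading
import time
from contextlib import contextmanager

_pool_active = False  # stands in for the allocator's per-device pool state

@contextmanager
def use_pool():
    global _pool_active
    # Analogous to the beginAllocateStreamToPool assertion.
    assert not _pool_active, "pool already active on this device"
    _pool_active = True
    try:
        time.sleep(0.01)  # widen the race window so the demo trips reliably
        yield
    finally:
        _pool_active = False

def worker():
    with use_pool():
        pass

# The main thread plus autograd's dummy thread entering concurrently is
# enough for one of them to hit the assert.
threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```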
I added logging for thread ID, stream, device, mem_pool, and stack trace on enter and exit of the _use_cuda_memory_pool_manager context manager.
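Roughly, the instrumentation looked like the sketch below; get_stack_trace is my own helper, reconstructed here from the frames that show up in the logs:

```python
# Approximate reconstruction of the logging patched into
# torch/_inductor/cudagraph_trees.py; get_stack_trace is a custom helper.
import sys
import threading
import traceback

def get_stack_trace(thread):
    # sys._current_frames() maps thread idents to their current frames.
    current_frames = sys._current_frames()
    frame_summary_list = traceback.extract_stack(current_frames[thread.ident])
    return "".join(traceback.format_list(frame_summary_list))

# Inside _use_cuda_memory_pool_manager, around the body of the context manager:
#     thread = threading.current_thread()
#     log.warning(f"{thread=} {device=} {mem_pool=} {stream=} ENTER stack trace:\n{get_stack_trace(thread)}")
#     ...
#     log.warning(f"{thread=} {device=} {mem_pool=} {stream=} EXIT")
```

The logs for the last time the dummy thread entered the context: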
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] thread=<_DummyThread(Dummy-7, started daemon 23350705448512)> device=0 mem_pool=(0, 2) stream=<torch.cuda.Stream device=cuda:0 cuda_stream=0x5555677224e0> ENTER stack trace:
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return user_fn(self, *args)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3233, in backward
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] out = call_compiled_backward()
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3205, in call_compiled_backward
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] out = call_func_with_args(
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1506, in call_func_with_args
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] out = normalize_as_list(f(args))
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return fn(*args, **kwargs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return fn(*args, **kwargs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 656, in __call__
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return self.get_current_callable()(inputs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 685, in run
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return compiled_fn(new_inputs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 368, in deferred_cudagraphify
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return fn(inputs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 633, in run
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return model(new_inputs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1773, in run
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] out = self._run(new_inputs, function_id)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1814, in _run
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return self.run_eager(new_inputs, function_id)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1929, in run_eager
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return node.run(new_inputs)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 630, in run
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] with torch.cuda.device(
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/apcv/shared/conda-envs/ai-1784/lib/python3.10/contextlib.py", line 135, in __enter__
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] return next(self.gen)
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 539, in _use_cuda_memory_pool_manager
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] log.warning(f"{thread=} {device=} {mem_pool=} {stream=} ENTER stack trace:\n{get_stack_trace(thread)}")
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 500, in get_stack_trace
[rank4]:[2023-11-17 20:49:35,293] torch._inductor.cudagraph_trees: [WARNING] frame_summary_list = traceback.extract_stack(current_frames[thread.ident])
At almost the exact same time, the main thread enters the context manager:
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] thread=<_MainThread(MainThread, started 23456244453376)> device=0 mem_pool=(0, 2) stream=<torch.cuda.Stream device=cuda:0 cuda_stream=0x5555677224e0> ENTER stack trace:
<bunch of frames from the application side, omitted for privacy>
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/repo/training_code.py", line <number>, in train_step
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] amp_scaler.scale(loss_val).backward()
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_tensor.py", line 508, in backward
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] torch.autograd.backward(
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return user_fn(self, *args)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3233, in backward
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] out = call_compiled_backward()
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3205, in call_compiled_backward
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] out = call_func_with_args(
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1506, in call_func_with_args
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] out = normalize_as_list(f(args))
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return fn(*args, **kwargs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return fn(*args, **kwargs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 656, in __call__
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return self.get_current_callable()(inputs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 685, in run
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return compiled_fn(new_inputs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 368, in deferred_cudagraphify
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return fn(inputs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 633, in run
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return model(new_inputs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1773, in run
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] out = self._run(new_inputs, function_id)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1814, in _run
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return self.run_eager(new_inputs, function_id)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1929, in run_eager
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return node.run(new_inputs)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 630, in run
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] with torch.cuda.device(
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/apcv/shared/conda-envs/ai-1784/lib/python3.10/contextlib.py", line 135, in __enter__
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] return next(self.gen)
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 539, in _use_cuda_memory_pool_manager
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] log.warning(f"{thread=} {device=} {mem_pool=} {stream=} ENTER stack trace:\n{get_stack_trace(thread)}")
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] File "/home/user/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 500, in get_stack_trace
[rank4]:[2023-11-17 20:49:35,297] torch._inductor.cudagraph_trees: [WARNING] frame_summary_list = traceback.extract_stack(current_frames[thread.ident])
FYI the error always happens after the first iteration.
> even if those are avoided we get stride assert errors.

@ErnestChan We recently merged https://github.com/pytorch/pytorch/pull/114154 (by @jon-chuang), which should fix the stride mismatch error when torch._dynamo.config.optimize_ddp=True. Would you like to try it again?
@ErnestChan Can you please try the fix from #114154 for the stride mismatch error and see if it solves the problem when using torch._dynamo.config.optimize_ddp=True? The PR is part of the PT 2.2 release.
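For reference, toggling DDPOptimizer is a one-line config change made before calling torch.compile. A sketch, where ddp_model stands in for your DDP-wrapped module:

```python
import torch
import torch._dynamo

# Let dynamo split the graph at DDP bucket boundaries (DDPOptimizer);
# with the #114154 fix this should no longer hit the stride mismatch.
torch._dynamo.config.optimize_ddp = True
compiled = torch.compile(ddp_model, mode="reduce-overhead")  # ddp_model: assumed DDP module
```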
@xmfan I wonder if you could dig into this theory a little bit:
I am wondering if find_unused_parameters sets a state value in DDP's reducer during each tracing attempt; if we then do multiple traces and run the compiled programs in a different order, perhaps the reducer still expects the parameter order it saw during the most recent compile attempt?
re torch/nn/parallel/distributed.py: self.reducer.prepare_for_backward(list(_find_tensors(output)))
wondering if the best thing to do is punt on this for graph-break DDP and focus on enabling it for compiled DDP (https://github.com/pytorch/pytorch/pull/110662) -- worth discussing with @fegin
[issue scrubbing] @xmfan I think the last request was for you, but this is currently assigned to @eellison. For either: there hasn't been any activity here for a while; is there still an issue to investigate?
🐛 Describe the bug
I'm running DDP + torch.compile with cudagraph_trees, but I'm running into errors related to CUDA graphs. This is without DDPOptimizer; I set torch._dynamo.config.optimize_ddp = False. For the 8-rank job, 2 of the ranks error due to failing the assert on this line.
The other ranks fail due to:
I am able to run torch.compile with cudagraph_trees on a single GPU, which suggests some interaction with DDP.
I have not yet been able to find a minimal example to reproduce. Would greatly appreciate pointers on how to debug further.
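In the meantime, a condensed sketch of the failing configuration as described above; ddp_model and batch are placeholders for the actual DDP-wrapped module and input:

```python
import torch
import torch._dynamo

torch._dynamo.config.optimize_ddp = False                    # DDPOptimizer disabled
compiled = torch.compile(ddp_model, mode="reduce-overhead")  # cudagraph trees on

out = compiled(batch)    # forward captures/replays under cudagraph trees
out.sum().backward()     # the asserts fire in backward on some ranks
```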
cc: @eellison
Versions
Ubuntu 22.04.1 LTS, PyTorch 2.1.0 built from source, H100 GPU
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mcarilli @ezyang @eellison @penguinwu @chauhang @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @XilunWu @msaroufim @bdhirsh @anijain2305 @zou3519 @peterbell10